[Pallas] Upstream pallas to JAX

PiperOrigin-RevId: 552963029
Sharad Vikram 2023-08-01 16:42:26 -07:00 committed by jax authors
parent 69cd3ebe99
commit d872812a35
35 changed files with 7730 additions and 3 deletions


@@ -159,7 +159,7 @@ jobs:
PY_COLORS: 1
run: |
pytest -n auto --tb=short docs
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic/__init__.py --ignore=jax/experimental/mosaic/dialects.py
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
documentation_render:


@@ -2,6 +2,7 @@ absl-py
build
cloudpickle
colorama>=0.4.4
hypothesis
numpy>=1.22
pillow>=9.1.0
portpicker


@@ -22,6 +22,8 @@ load(
"jax_test_util_visibility",
"jax_visibility",
"mosaic_internal_users",
"pallas_gpu_internal_users",
"pallas_tpu_internal_users",
"py_deps",
"py_library_providing_imports_info",
"pytype_library",
@@ -67,6 +69,20 @@ package_group(
] + mosaic_internal_users,
)
package_group(
name = "pallas_gpu_users",
packages = [
"//...",
] + pallas_gpu_internal_users,
)
package_group(
name = "pallas_tpu_users",
packages = [
"//...",
] + pallas_tpu_internal_users,
)
# JAX-private test utilities.
py_library(
# This build target is required in order to use private test utilities in jax._src.test_util,
@@ -430,6 +446,52 @@ pytype_strict_library(
] + py_deps("numpy"),
)
pytype_strict_library(
name = "pallas",
srcs = glob(
[
"experimental/pallas/**/*.py",
],
exclude = [
"experimental/pallas/gpu.py",
"experimental/pallas/tpu.py",
],
),
visibility = [
"//visibility:public",
],
deps = [
":jax",
"//jax/_src/pallas",
] + py_deps("numpy"),
)
pytype_strict_library(
name = "pallas_tpu",
srcs = ["experimental/pallas/tpu.py"],
visibility = [
":pallas_tpu_users",
],
deps = [
":pallas",
"//jax/_src/pallas/mosaic",
"//jax/_src/pallas/mosaic:core",
"//jax/_src/pallas/mosaic:primitives",
],
)
pytype_strict_library(
name = "pallas_gpu",
srcs = ["experimental/pallas/gpu.py"],
visibility = [
":pallas_gpu_users",
],
deps = [
":pallas",
"//jax/_src/pallas/triton",
],
)
pytype_strict_library(
name = "partial_eval",
srcs = ["_src/interpreters/partial_eval.py"],


@@ -40,15 +40,20 @@ py_library_providing_imports_info(
"//jaxlib",
"//jaxlib:cpu_feature_guard",
"//jaxlib:utils",
"//jaxlib/mlir:arithmetic_dialect",
"//jaxlib/mlir:builtin_dialect",
"//jaxlib/mlir:chlo_dialect",
"//jaxlib/mlir:func_dialect",
"//jaxlib/mlir:ir",
"//jaxlib/mlir:math_dialect",
"//jaxlib/mlir:memref_dialect",
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:ml_program_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib/mlir:stablehlo_dialect",
"//jaxlib/mlir:vector_dialect",
# xla_client
],
"//conditions:default": [],


@@ -103,6 +103,7 @@ import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error
# Jaxlib code is split between the Jax and the Tensorflow repositories.
# Only for the internal usage of the JAX developers, we expose a version


@@ -20,5 +20,14 @@ import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.ml_program as ml_program
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
from jax._src import lib
# TODO(sharadmv): remove guard when minimum jaxlib version is bumped
if lib.version >= (0, 4, 15):
import jaxlib.mlir.dialects.arith as arith
import jaxlib.mlir.dialects.math as math
import jaxlib.mlir.dialects.memref as memref
import jaxlib.mlir.dialects.scf as scf
import jaxlib.mlir.dialects.vector as vector
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
import jaxlib.mlir.dialects.stablehlo as hlo

jax/_src/pallas/BUILD (new file, 64 lines)

@@ -0,0 +1,64 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"py_deps",
)
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
],
)
py_library(
name = "pallas",
srcs = glob(
include = ["**/*.py"],
exclude = [
"triton/*.py",
"mosaic/*.py",
],
),
deps = [
"//jax",
"//jax:ad_util",
"//jax:core",
"//jax:mlir",
"//jax:partial_eval",
"//jax:pretty_printer",
"//jax:tree_util",
"//jax:util",
"//jax/_src/lib",
] + py_deps("numpy"),
)
py_library(
name = "gpu",
visibility = [],
deps = [
":pallas",
"//jax/_src/pallas/triton",
],
)
py_library(
name = "tpu",
visibility = [],
deps = [
":pallas",
"//jax/_src/pallas/mosaic",
],
)


@@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

jax/_src/pallas/core.py (new file, 229 lines)

@@ -0,0 +1,229 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for pallas-core functionality."""
from __future__ import annotations
from collections.abc import Sequence
import contextlib
import dataclasses
import functools
from typing import Any, Callable, Iterator
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
import jax.numpy as jnp
# TODO(sharadmv): enable type checking
# mypy: ignore-errors
partial = functools.partial
Grid = tuple[int, ...]
split_list = util.split_list
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
@dataclasses.dataclass
class GridEnv:
axis_index: Any
axis_size: int
_grid_env_stack: list[tuple[GridEnv, ...]] = []
@contextlib.contextmanager
def grid_env(env: tuple[tuple[Any, int], ...]) -> Iterator[None]:
_grid_env_stack.append(tuple(GridEnv(axis_index, axis_size)
for axis_index, axis_size in env))
try:
yield
finally:
_grid_env_stack.pop()
def current_grid_env() -> tuple[GridEnv, ...] | None:
if not _grid_env_stack:
return None
return _grid_env_stack[-1]
class Mapped:
pass
mapped = Mapped()
@dataclasses.dataclass(frozen=True)
class BlockSpec:
index_map: Callable[..., Any]
block_shape: tuple[int | None, ...]
def compute_index(self, *args):
out = self.index_map(*args)
if not isinstance(out, tuple):
out = (out,)
return out
@dataclasses.dataclass(frozen=True)
class BlockMapping:
block_shape: tuple[Mapped | int, ...]
index_map_jaxpr: jax_core.ClosedJaxpr
def compute_start_indices(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts
)
jaxpr = jax_core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
block_indices_and_rest = jax_core.jaxpr_as_fun(jaxpr)(*loop_idx, *args)
# Since we're passing in `Ref`s potentially, we need to split out their
# updated values since we only care about the return values.
block_indices, _ = split_list(block_indices_and_rest,
[len(self.block_shape)])
return tuple(i if b is mapped else b * i
for b, i in zip(self.block_shape, block_indices))
replace = dataclasses.replace
@dataclasses.dataclass(frozen=True)
class GridMapping:
grid: tuple[int, ...]
block_mappings: tuple[BlockMapping | None, ...]
mapped_dims: tuple[int, ...]
num_index_operands: int
replace = dataclasses.replace
def _preprocess_grid(grid: Grid | int | None) -> Grid:
if grid is None:
return ()
if isinstance(grid, int):
return (grid,)
return grid
def _convert_block_spec_to_block_mapping(
in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None,
) -> BlockMapping | None:
if block_spec is _no_block_spec:
return None
block_shape = tuple(
mapped if s is None else s for s in block_spec.block_shape)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(block_spec.compute_index), in_avals)
return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts))
def _compute_shape_from_block_spec(block_spec: BlockSpec | None,
arg_shape: tuple[int, ...]
) -> tuple[int, ...]:
if block_spec is _no_block_spec:
return arg_shape
return tuple(s for s in block_spec.block_shape if s is not None)
def _get_ref_avals(grid, in_avals, in_specs, out_avals, out_specs):
if grid is None:
in_specs = [None] * len(in_avals)
out_specs = [None] * len(out_avals)
in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
for arg in in_avals]
out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
for arg in out_avals]
else:
in_ref_avals = [
state.shaped_array_ref(
_compute_shape_from_block_spec(
block_spec, arg.shape), arg.dtype)
for block_spec, arg in zip(in_specs, in_avals)]
out_ref_avals = [
state.shaped_array_ref(
_compute_shape_from_block_spec(
block_spec, arg.shape), arg.dtype)
for block_spec, arg in zip(out_specs, out_avals)]
return in_specs, in_ref_avals, out_specs, out_ref_avals
_no_block_spec = object()
@dataclasses.dataclass(init=False)
class GridSpec:
grid: Grid
in_specs: Sequence[BlockSpec | None] | None
out_specs: tuple[BlockSpec | None, ...] | None
def __init__(
self,
grid: Grid | None = None,
in_specs: Sequence[BlockSpec | None] | None = None,
out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
):
if grid is None:
if in_specs is not None:
raise ValueError("Cannot specify `in_specs` with a `None` grid.")
if out_specs is not None:
raise ValueError("Cannot specify `out_specs` with a `None` grid.")
self.grid = _preprocess_grid(grid)
self.in_specs = in_specs
if out_specs is not None and not isinstance(out_specs, (tuple, list)):
out_specs = (out_specs,)
if out_specs is not None and not isinstance(out_specs, tuple):
out_specs = tuple(out_specs)
self.out_specs = out_specs
def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
if self.in_specs is not None:
in_specs = self.in_specs
in_spec_tree = tree_util.tree_structure(tuple(in_specs))
if in_spec_tree != in_tree:
raise ValueError(
"Pytree specs for arguments and `in_specs` must match: "
f"{in_tree} vs. {in_spec_tree}")
else:
in_specs = [_no_block_spec] * len(in_avals)
if self.out_specs is not None:
out_specs = self.out_specs
out_spec_tree = tree_util.tree_structure(out_specs)
if out_spec_tree != out_tree:
raise ValueError(
"Pytree specs for `out_shape` and `out_specs` must match: "
f"{out_tree} vs. {out_spec_tree}")
else:
out_specs = [_no_block_spec] * len(out_avals)
flat_in_specs = tree_util.tree_leaves(in_specs)
flat_out_specs = tree_util.tree_leaves(out_specs)
in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
self.grid, in_avals, flat_in_specs, out_avals,
flat_out_specs)
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
in_block_mappings = map(
partial(_convert_block_spec_to_block_mapping, grid_avals), in_specs)
out_block_mappings = map(
partial(_convert_block_spec_to_block_mapping, grid_avals), out_specs)
grid_mapping = GridMapping(
self.grid, (*in_block_mappings, *out_block_mappings), (),
num_index_operands=0)
jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
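For orientation, a hedged sketch of how the `BlockSpec`/`GridSpec` machinery above is meant to be used: an `index_map` returns *block* indices for a given grid position, and `compute_start_indices` turns them into element offsets by multiplying with the block shape (mapped dimensions pass the index through unchanged). The `pl.BlockSpec`/`pl.GridSpec` re-export names below are assumptions, not shown in this hunk.

# Sketch only: assumes BlockSpec/GridSpec are re-exported via
# jax.experimental.pallas (added elsewhere in this commit).
from jax.experimental import pallas as pl

# A (512, 512) operand split into (128, 128) blocks. The index_map receives the
# grid indices (i, j) and returns block indices; element start offsets are
# block_index * block_size, i.e. (128 * i, 128 * j).
block_spec = pl.BlockSpec(lambda i, j: (i, j), (128, 128))

grid_spec = pl.GridSpec(
    grid=(4, 4),                      # 4 x 4 grid of blocks
    in_specs=[block_spec, block_spec],
    out_specs=pl.BlockSpec(lambda i, j: (i, j), (128, 128)),
)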

jax/_src/pallas/indexing.py (new file, 158 lines)

@@ -0,0 +1,158 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains shared logic and abstractions for Pallas indexing ops."""
from __future__ import annotations
import dataclasses
from typing import Any, Tuple
import jax
from jax import core as jax_core
from jax import tree_util
from jax._src.interpreters import mlir
from jax._src.util import merge_lists
from jax._src.util import partition_list
import jax.numpy as jnp
import numpy as np
# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
# We could use `jnp.broadcast_to` but that lowers to squeezing,
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
# so in the lowering, we have to expand out those squeezed dimensions again.
# Having a simple `broadcast_to` primitive allows us to lower directly
# to `tl.broadcast_to`.
broadcast_to_p = jax_core.Primitive('broadcast_to')
def broadcast_to(a: jax.Array, shape: Tuple[int, ...]) -> jax.Array:
if a.shape == shape:
return a
return broadcast_to_p.bind(a, shape=shape)
@broadcast_to_p.def_impl
def _broadcast_to_impl(a, *, shape):
return jnp.broadcast_to(a, shape)
@broadcast_to_p.def_abstract_eval
def _broadcast_to_abstract_eval(aval, *, shape):
return jax_core.ShapedArray(shape, aval.dtype)
mlir.register_lowering(
broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
"""Represents a slice with a dynamic start index and a fixed size."""
start: Any
size: int
def tree_flatten(self):
# If `start` is statically known, we treat it as static information
if isinstance(self.start, int):
return (), (True, self.start, self.size)
return (self.start,), (False, self.size)
@classmethod
def tree_unflatten(cls, data, xs):
is_static = data[0]
if is_static:
del xs
start, size = data[1:]
return Slice(start, size)
start, = xs
size = data[1]
return Slice(start, size)
@classmethod
def from_slice(cls, slc: slice, size: int) -> Slice:
start, stop = slc.start, slc.stop
start = 0 if start is None else start
stop = size if stop is None else stop
return Slice(start, stop - start)
def dslice(start: int | jax.Array | None, size: int | None = None
) -> slice | Slice:
"""Constructs a `Slice` from a start and a size."""
if start is None:
return slice(None)
if size is None:
if not isinstance(start, int):
raise ValueError("Non-static `dslice`")
return Slice(0, start)
return Slice(start, size)
ds = dslice # Handy alias
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
indices: Tuple[int | Slice | jax.Array, ...]
shape: Tuple[int, ...]
int_indexer_shape: Tuple[int, ...]
def __post_init__(self):
if len(self.indices) != len(self.shape):
raise ValueError("`indices` must be the same length as `Ref` shape.")
def tree_flatten(self):
indexed_dims = [not isinstance(idx, slice) for idx in self.indices]
slice_idx, non_slice_idx = partition_list(indexed_dims, self.indices)
flat_idx, idx_tree = tree_util.tree_flatten(non_slice_idx)
return flat_idx, (slice_idx, idx_tree, indexed_dims, self.shape,
self.int_indexer_shape)
@classmethod
def tree_unflatten(cls, data, flat_idx):
slice_idx, idx_tree, indexed_dims, shape, int_indexer_shape = data
non_slice_idx = tree_util.tree_unflatten(idx_tree, flat_idx)
indices = merge_lists(indexed_dims, slice_idx, non_slice_idx)
return NDIndexer(tuple(indices), shape, int_indexer_shape)
@classmethod
def from_indices_shape(cls, indices, shape) -> NDIndexer:
if len(indices) > len(shape):
raise ValueError("`indices` must be the no longer than `shape`.")
# Pad out indices with slice(None)
indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
# Convert all `slice`s to `Slice`s
indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
else i for i, s in zip(indices, shape))
is_int_indexing = [not isinstance(i, Slice) for i in indices]
other_indexers, int_indexers = partition_list(is_int_indexing, indices)
int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in
int_indexers]
indexer_shapes = [i.shape for i in int_indexers]
if indexer_shapes:
try:
bcast_shape = np.broadcast_shapes(*indexer_shapes)
except ValueError as e:
# Raise a nicer error than the NumPy one.
raise ValueError("Cannot broadcast shapes for indexing: "
f"{tuple(a for a in indexer_shapes)}") from e
else:
bcast_shape = ()
int_indexers = [broadcast_to(i, bcast_shape) for i in int_indexers]
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
return NDIndexer(tuple(indices), shape, bcast_shape)
def get_indexer_shape(self) -> Tuple[int, ...]:
is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
other_indexers, _ = partition_list(is_int_indexing, self.indices)
other_shape = [s.size for s in other_indexers] # type: ignore
return tuple((*self.int_indexer_shape, *other_shape))
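As a rough usage sketch (not part of the diff): `dslice`/`Slice` describe a dynamically-offset, statically-sized window, and `NDIndexer.from_indices_shape` normalizes a mixed tuple of integers and slices, broadcasting integer indexers against each other via the `broadcast_to` primitive above. The internal import path is used purely for illustration.

# Sketch only: exercising the indexing helpers defined in this file.
from jax._src.pallas import indexing

# A fixed-size window starting at a (possibly traced) offset: [4, 4 + 8).
s = indexing.dslice(4, 8)       # -> Slice(start=4, size=8)

# Normalize a mixed index for a (16, 128)-shaped Ref: an integer index on
# dim 0 and a python slice on dim 1, which becomes a static Slice(0, 128).
idx = indexing.NDIndexer.from_indices_shape((3, slice(None)), (16, 128))
print(idx.get_indexer_shape())  # (128,)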


@@ -0,0 +1,89 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Package for Mosaic-specific Pallas extensions
load(
"//jaxlib:jax.bzl",
"py_deps",
"py_library_providing_imports_info",
)
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
],
)
py_library_providing_imports_info(
name = "mosaic",
srcs = ["__init__.py"],
lib_rule = py_library,
deps = [
":core",
":pallas_call_registration",
":primitives",
],
)
py_library(
name = "core",
srcs = ["core.py"],
deps = [
"//jax",
"//jax/_src/pallas",
],
)
py_library(
name = "primitives",
srcs = ["primitives.py"],
deps = [
"//jax",
"//jax:core",
"//jax:mlir",
],
)
py_library(
name = "pallas_call_registration",
srcs = ["pallas_call_registration.py"],
deps = [
":lowering",
"//jax",
"//jax:mosaic",
"//jax:source_info_util",
"//jax/_src/lib",
"//jax/_src/pallas",
] + py_deps("numpy"),
)
py_library(
name = "lowering",
srcs = ["lowering.py"],
deps = [
":core",
":primitives",
"//jax",
"//jax:core",
"//jax:mlir",
"//jax:mosaic",
"//jax:partial_eval",
"//jax:source_info_util",
"//jax:util",
"//jax:xla",
"//jax/_src/lib",
"//jax/_src/pallas",
] + py_deps("numpy"),
)


@@ -0,0 +1,28 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for Mosaic lowering of Pallas call."""
from jax._src.pallas.mosaic import core
from jax._src.pallas.mosaic import pallas_call_registration
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import trace
VMEM = TPUMemorySpace.VMEM
SMEM = TPUMemorySpace.SMEM
CMEM = TPUMemorySpace.CMEM
del pallas_call_registration


@@ -0,0 +1,104 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains TPU-specific Pallas abstractions."""
from __future__ import annotations
from collections.abc import Sequence
import dataclasses
import enum
import functools
from jax._src import core as jax_core
from jax._src import state
from jax._src import tree_util
from jax._src import util
import jax.numpy as jnp
from jax._src.pallas import core as pallas_core
# TODO(sharadmv): enable type checking
# mypy: ignore-errors
partial = functools.partial
Grid = pallas_core.Grid
BlockSpec = pallas_core.BlockSpec
GridMapping = pallas_core.GridMapping
_preprocess_grid = pallas_core._preprocess_grid
_compute_shape_from_block_spec = pallas_core._compute_shape_from_block_spec
_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
split_list = util.split_list
class TPUMemorySpace(enum.Enum):
VMEM = "vmem"
SMEM = "smem"
CMEM = "cmem"
def __str__(self) -> str:
return self.value
@dataclasses.dataclass(init=False)
class PrefetchScalarGridSpec(pallas_core.GridSpec):
grid: Grid
num_scalar_prefetch: int
in_specs: Sequence[BlockSpec | None] | None
out_specs: tuple[BlockSpec | None, ...] | None
def __init__(
self,
num_scalar_prefetch: int,
grid: Grid | None = None,
in_specs: Sequence[BlockSpec | None] | None = None,
out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
):
if grid is None:
raise NotImplementedError("Should pass in non-`None` grid.")
self.grid = _preprocess_grid(grid)
if out_specs is not None and not isinstance(out_specs, (tuple, list)):
out_specs = (out_specs,)
if out_specs is not None and not isinstance(out_specs, tuple):
out_specs = tuple(out_specs)
self.num_scalar_prefetch = num_scalar_prefetch
self.in_specs = in_specs
self.out_specs = out_specs
def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
scalar_avals, in_avals = split_list(in_avals, [self.num_scalar_prefetch])
in_specs, in_ref_avals, out_specs, out_ref_avals = (
pallas_core._get_ref_avals(
self.grid, in_avals, self.in_specs,
out_avals, self.out_specs))
scalar_ref_avals = [
state.shaped_array_ref(aval.shape, aval.dtype)
for aval in scalar_avals]
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
in_block_mappings = map(
partial(_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals)), in_specs)
out_block_mappings = map(
partial(_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals)), out_specs)
grid_mapping = GridMapping(
grid=self.grid,
block_mappings=(*in_block_mappings, *out_block_mappings),
mapped_dims=(),
num_index_operands=self.num_scalar_prefetch,
)
jaxpr_in_avals = tree_util.tree_unflatten(
in_tree, [*scalar_ref_avals, *in_ref_avals])
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
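A hedged sketch of what `PrefetchScalarGridSpec` implies for callers: the first `num_scalar_prefetch` operands become scalar `Ref`s that are passed to every `index_map` (after the grid indices) as well as to the kernel itself, so block indices can be looked up from prefetched data. The `pl`/`pltpu` re-export names below are assumptions, not shown in this diff.

# Sketch only: assumed re-export paths (jax.experimental.pallas[.tpu]).
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

BLOCK = 128
grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=1,            # the first operand is a scalar-prefetch Ref
    grid=(8,),
    # Each index_map receives (grid index, scalar Ref); here the scalar Ref
    # holds a table of block indices, read with s_ref[i] at trace time.
    in_specs=[pl.BlockSpec(lambda i, s_ref: (s_ref[i], 0), (BLOCK, BLOCK))],
    out_specs=pl.BlockSpec(lambda i, s_ref: (i, 0), (BLOCK, BLOCK)),
)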

File diff suppressed because it is too large.


@@ -0,0 +1,76 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains registrations for pallas_call on TPU."""
from __future__ import annotations
from typing import Any
import jax
from jax import core as jax_core
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
from jax.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core
from jax._src.pallas.mosaic import lowering
from jax._src.pallas.pallas_call import pallas_call_p
def pallas_call_tpu_lowering_rule(
ctx: mlir.LoweringRuleContext, *in_nodes,
jaxpr: jax_core.Jaxpr,
name: str,
which_linear: tuple[bool, ...],
grid_mapping: core.GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
debug: bool,
interpret: bool,
mosaic_params: dict[str, Any] | None = None,
**compiler_params: Any):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
if input_output_aliases:
raise NotImplementedError(
"`input_output_aliases` not supported on TPU backend.")
if interpret:
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
in_shapes=in_shapes,
which_linear=which_linear,
interpret=interpret, debug=debug,
input_output_aliases=input_output_aliases,
grid_mapping=grid_mapping, **compiler_params)
if debug:
print(jaxpr)
with ir.Context() as mlir_ctx, ir.Location.unknown():
tpu.register_dialect(mlir_ctx)
if mosaic_params is None:
mosaic_params = {}
dimension_semantics = mosaic_params.get("dimension_semantics", None)
mosaic_module = lowering.lower_jaxpr_to_module(
mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics)
if debug:
print(mosaic_module)
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
return mlir.lower_fun(
mosaic.as_tpu_kernel(
mosaic_module, out_avals, backend=ctx.module_context.backend
),
multiple_results=True,
)(ctx, *in_nodes)
mlir.register_lowering(pallas_call_p, pallas_call_tpu_lowering_rule,
platform="tpu")


@@ -0,0 +1,66 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for Pallas:TPU-specific JAX primitives and functions."""
from __future__ import annotations
import contextlib
from jax._src import core as jax_core
from jax._src.interpreters import mlir
import jax.numpy as jnp
repeat_p = jax_core.Primitive('repeat')
def repeat(x, repeats, axis):
return repeat_p.bind(x, repeats=repeats, axis=axis)
@repeat_p.def_abstract_eval
def _repeat_abstract_eval(x, *, repeats, axis):
shape = list(x.shape)
shape[axis] *= repeats
return jax_core.ShapedArray(shape, x.dtype)
def _repeat_lowering_rule(ctx: mlir.LoweringRuleContext, x, *, repeats, axis):
def _repeat(x):
return jnp.repeat(x, repeats, axis)
return mlir.lower_fun(_repeat, multiple_results=False)(ctx, x)
mlir.register_lowering(repeat_p, _repeat_lowering_rule)
trace_start_p = jax_core.Primitive('trace_start')
trace_start_p.multiple_results = True
@trace_start_p.def_abstract_eval
def _trace_start_abstract_eval(*, message: str, level: int):
del message, level
return []
trace_stop_p = jax_core.Primitive('trace_stop')
trace_stop_p.multiple_results = True
@trace_stop_p.def_abstract_eval
def _trace_stop_abstract_eval():
return []
@contextlib.contextmanager
def trace(message: str, level: int = 10):
trace_start_p.bind(message=message, level=level)
yield
trace_stop_p.bind()
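For orientation (not part of the diff), `repeat` and `trace` are intended for use inside a Pallas TPU kernel body; `trace` brackets a region with the start/stop primitives above so it appears as a named span in TPU traces. The `pltpu` re-export path is assumed.

# Sketch only: assumed usage inside a Pallas TPU kernel body.
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref):
  with pltpu.trace("tile_block"):          # emits trace_start / trace_stop
    # repeat tiles the block 2x along axis 1, so o_ref's block is twice as wide.
    o_ref[...] = pltpu.repeat(x_ref[...], repeats=2, axis=1)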


@@ -0,0 +1,372 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for calling pallas functions from JAX."""
from __future__ import annotations
from functools import partial
import itertools as it
from typing import Any, Callable, Dict, Sequence, Tuple
import jax
from jax import api_util
from jax import linear_util as lu
from jax import tree_util
from jax import lax
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src.state import discharge as state_discharge
from jax._src.util import (
split_list, safe_map, safe_zip, weakref_lru_cache,
tuple_insert, partition_list)
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp
import numpy as np
from jax._src.pallas import core as pallas_core
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Grid = pallas_core.Grid
BlockSpec = pallas_core.BlockSpec
GridSpec = pallas_core.GridSpec
BlockMapping = pallas_core.BlockMapping
GridMapping = pallas_core.GridMapping
pallas_call_p = jax_core.Primitive('pallas_call')
pallas_call_p.multiple_results = True
def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
if start_idx is None:
assert is_indexing is None
return value
assert is_indexing is not None
output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing,
dtype=np.bool_)])
return lax.squeeze(output, squeeze_dims)
def _maybe_dynamic_update_slice(start_idx, block_shape, value, update,
is_indexing):
if start_idx is None:
assert is_indexing is None
return update
assert is_indexing is not None
broadcast_dims = tuple(i for i, b in enumerate(is_indexing)
if not b)
update = lax.broadcast_in_dim(update, block_shape, broadcast_dims)
assert update.shape == block_shape
return lax.dynamic_update_slice(value, update, start_idx)
def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
interpret, debug: bool,
in_shapes,
input_output_aliases: Tuple[Tuple[int, int], ...],
grid_mapping: GridMapping,
**compiler_params: Any):
if interpret:
# If we're in interpreter mode, we *scan* over the grid and eval the
# discharged jaxpr. This should reproduce exactly what compiling to Triton
# will do.
grid = grid_mapping.grid
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
loop_indices = jnp.array(list(it.product(*(range(g) for g in grid))))
oi_map = {v: k for k, v in input_output_aliases}
out = []
for i, out_shape in enumerate(out_shapes):
if i in oi_map:
out.append(args[oi_map[i]])
else:
out.append(jnp.zeros(out_shape.shape, out_shape.dtype))
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
carry = [*args, *out]
def cond(carry):
return carry[0] < loop_indices.shape[0]
def body(carry):
i, *carry = carry
loop_idx = loop_indices[i]
start_indices = [
None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
block_shapes_without_mapped_dims = [
None if block_mapping is None else block_mapping.block_shape
for block_mapping in grid_mapping.block_mappings
]
is_indexing_dim = [
None if bm is None else tuple(b is pallas_core.mapped for b in bm)
for bm in block_shapes_without_mapped_dims
]
block_shapes = [
None if bm is None else tuple(1 if i else b for i, b in zip(iid, bm))
for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
is_indexing_dim)
is_mapped_grid_dim = [
i in grid_mapping.mapped_dims for i in range(len(grid_mapping.grid))]
local_grid_env, _ = partition_list(is_mapped_grid_dim,
zip(loop_idx, grid_mapping.grid))
with pallas_core.grid_env(tuple(local_grid_env)):
blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
*blocks)
blocks = blocks[grid_mapping.num_index_operands:]
carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry, blocks, is_indexing_dim)
return (i + 1, *carry)
(_, *carry) = lax.while_loop(cond, body, (0, *carry))
_, out = split_list(carry, [len(args)])
return out
return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
in_shapes=in_shapes,
out_shapes=out_shapes, which_linear=which_linear,
grid_mapping=grid_mapping, interpret=interpret,
debug=debug,
input_output_aliases=input_output_aliases,
**compiler_params)
pallas_call_p.def_impl(_pallas_call_impl)
def _pallas_call_abstract_eval(*avals, out_shapes, **_):
return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes)
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
input_output_aliases: Tuple[Tuple[int, int], ...],
in_shapes, out_shapes, grid_mapping, debug, interpret, **compiler_params: Any):
if grid_mapping.num_index_operands:
raise NotImplementedError
if input_output_aliases:
raise NotImplementedError("JVP with aliasing not supported.")
nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
tangents = [ad.instantiate_zeros(t) if inst else t
for t, inst in zip(tangents, nonzero_tangents)]
tangents = [t for t in tangents if type(t) is not ad_util.Zero]
nonzero_tangents_with_outputs = nonzero_tangents + [True] * len(out_shapes)
closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts
jvp_which_linear = which_linear + (True,) * len(tangents)
jvp_inshapes = (*in_shapes, *in_shapes)
jvp_outshapes = (*out_shapes, *out_shapes)
if input_output_aliases:
raise NotImplementedError("`input_output_aliases` jvp not supported.")
# `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*.
# `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input
# `Ref`s that are read from followed by output `Ref`s that are written to.
# This means that when we do `jvp_jaxpr` on the `jaxpr`, we get out a new
# jaxpr that has tangents following primals. In order for this jaxpr to be
# compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around
# the jaxpr's invars.
logical_primals, logical_tangents = split_list(
jvp_jaxpr.invars, [len(primals) + len(out_shapes)])
logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)])
logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)])
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms))
new_grid_mapping = grid_mapping.replace(block_mappings=new_bms)
jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs,
*logical_tangent_inputs,
*logical_primal_outputs,
*logical_tangent_outputs])
if debug:
print(jvp_jaxpr)
out_flat = pallas_call_p.bind(*primals, *tangents, jaxpr=jvp_jaxpr,
name=f"{name}_jvp",
in_shapes=jvp_inshapes,
out_shapes=jvp_outshapes,
grid_mapping=new_grid_mapping,
which_linear=jvp_which_linear,
interpret=interpret,
debug=debug,
input_output_aliases=(),
**compiler_params)
out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
return out_primals, out_tangents
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
dim: int | batching.NotMapped,
block_mapping: BlockMapping | None) -> BlockMapping:
def _block_map_function(new_idx, *args):
if block_mapping is None:
indices = [0] * len(aval.shape)
else:
indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr,
block_mapping.index_map_jaxpr.consts,
*args)
if dim is not batching.not_mapped:
indices.insert(dim, new_idx)
return tuple(indices)
i32_aval = jax_core.ShapedArray((), jnp.int32)
if block_mapping is None:
idx_avals = [i32_aval] * (len(grid) + 1)
else:
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
shape = aval.shape if block_mapping is None else block_mapping.block_shape
if dim is batching.not_mapped:
new_block_shape = shape
else:
new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
if block_mapping is None:
return BlockMapping(block_shape=new_block_shape, index_map_jaxpr=jaxpr)
return block_mapping.replace(block_shape=new_block_shape,
index_map_jaxpr=jaxpr)
def _pallas_call_batching_rule(args, dims, *,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
grid_mapping: GridMapping,
input_output_aliases: Tuple[Tuple[int, int], ...],
debug: bool,
interpret: bool,
which_linear: Tuple[bool, ...],
**compiler_params: Any):
if grid_mapping.num_index_operands:
scalar_batch_dims = dims[:grid_mapping.num_index_operands]
if any(bdim is not batching.not_mapped for bdim in scalar_batch_dims):
# TODO(sharadmv,apaszke): enable batching over prefetched scalar args
raise NotImplementedError
axis_size, = {x.shape[d] for x, d in zip(args, dims)
if d is not batching.not_mapped}
block_mappings = grid_mapping.block_mappings
avals = [v.aval for v in jaxpr.invars]
  # How should we pick output dimensions? This actually matters because XLA
  # can't optimize our pallas kernels, and this layout impacts performance.
  # Because `vmap` doesn't really offer a way of inferring good output
  # dimensions, for now we just use 0.
# TODO(sharadmv): explore inferring better output dimensions via a heuristic
# TODO(sharadmv): explore a long term solution to output dim inference
# When we have input/output aliasing, since the output will be mapped, we need
# to make sure to broadcast the input across that dimension if it is not
# mapped.
dims_ = list(dims)
args_ = list(args)
for input_index, _ in input_output_aliases:
dim = dims_[input_index]
if dim is batching.not_mapped:
dims_[input_index] = 0
args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
args = tuple(args_)
dims = tuple(dims_)
all_dims = list(dims) + [0] * len(out_shapes)
num_index_operands = grid_mapping.num_index_operands
batched_block_mappings = map(
partial(_batch_block_mapping, grid_mapping.grid),
avals[num_index_operands:], all_dims[num_index_operands:], block_mappings)
batched_in_shapes = tuple(
jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
tuple_insert(x.shape, dim, axis_size),
x.dtype)
for x, dim in zip(in_shapes, dims))
batched_out_shapes = tuple(
jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
for x in out_shapes)
batched_grid_mapping = grid_mapping.replace(
grid=(axis_size, *grid_mapping.grid),
block_mappings=tuple(batched_block_mappings),
mapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.mapped_dims))
out = pallas_call_p.bind(*args, jaxpr=jaxpr, name=f"batched_{name}",
in_shapes=batched_in_shapes,
out_shapes=batched_out_shapes,
which_linear=which_linear,
grid_mapping=batched_grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
**compiler_params)
return out, (0,) * len(out)
batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
@weakref_lru_cache
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: str | None = None):
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, out_tree_thunk, False,
primitive_name or "<unknown>")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
jaxpr = for_loop._hoist_consts_to_refs(jaxpr)
return jaxpr, consts, out_tree_thunk()
def _extract_function_name(f: Callable, name: str | None) -> str:
if name is None:
name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func"
return name
def pallas_call(
f: Callable[..., None], out_shape: Any, *,
grid_spec: GridSpec | None = None,
debug: bool = False,
grid: Grid | None = None,
in_specs: Sequence[BlockSpec | None] | None = None,
out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
input_output_aliases: Dict[int, int] = {},
interpret: bool = False,
name: str | None = None,
**compiler_params: Any):
if grid_spec is None:
grid_spec = GridSpec(grid, in_specs, out_specs)
name = _extract_function_name(f, name)
singleton = False
if not isinstance(out_shape, (tuple, list)):
out_shape = (out_shape,)
singleton = True
if not isinstance(out_shape, tuple):
out_shape = tuple(out_shape)
flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
for x in flat_out_shapes]
@jax.jit
def wrapped(*args):
flat_args, in_tree = tree_util.tree_flatten(args)
flat_avals = [jax_core.raise_to_shaped(jax_core.get_aval(a))
for a in flat_args]
avals, grid_mapping = grid_spec.get_grid_mapping(flat_avals, in_tree,
flat_out_shapes, out_tree)
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
jaxpr, consts, _ = _initial_style_open_jaxpr(f, jaxpr_in_tree,
tuple(jaxpr_flat_avals),
primitive_name="pallas_call")
which_linear = (False,) * len(flat_args)
out_flat = pallas_call_p.bind(
*consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
for a in flat_args),
out_shapes=tuple(flat_out_shapes), debug=debug,
interpret=interpret,
grid_mapping=grid_mapping,
input_output_aliases=tuple(input_output_aliases.items()),
**compiler_params)
out = tree_util.tree_unflatten(out_tree, out_flat)
if singleton:
return out[0]
return out
return wrapped
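Putting it together, a minimal end-to-end sketch of the `pallas_call` API added here (not from the diff; the `jax.experimental.pallas` re-export is assumed). With `interpret=True` the call goes through the grid-scanning interpreter in `_pallas_call_impl` above, so it runs on any backend:

# Sketch only: a minimal elementwise kernel through pallas_call.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl   # assumed public re-export

def add_kernel(x_ref, y_ref, o_ref):
  # Kernel arguments are Refs; [...] reads/writes the whole block.
  o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
  return pl.pallas_call(
      add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=True,   # interpreter path, no GPU/TPU required
  )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))        # [ 0.  2.  4.  6.  8. 10. 12. 14.]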


@@ -0,0 +1,411 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for pallas-specific JAX primitives and functions."""
from __future__ import annotations
import enum
import functools
from typing import Any
import jax
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src import pretty_printer as pp
from jax._src import state
from jax._src.util import (safe_map, safe_zip)
from jax._src.state import primitives as state_primitives
from jax._src.state import discharge as state_discharge
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
import jax.numpy as jnp
from jax._src.pallas import core as pallas_core
from jax._src.pallas import indexing
# TODO(sharadmv): enable type checking
# mypy: ignore-errors
partial = functools.partial
Slice = indexing.Slice
NDIndexer = indexing.NDIndexer
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
program_id_p = jax_core.Primitive("program_id")
def program_id(axis):
return program_id_p.bind(axis=axis)
def program_id_bind(*, axis: int):
grid_env = pallas_core.current_grid_env()
if grid_env:
return grid_env[axis].axis_index
return jax_core.Primitive.bind(program_id_p, axis=axis)
program_id_p.def_custom_bind(program_id_bind)
def _program_id_impl(*, axis: int):
grid_env = pallas_core.current_grid_env()
return grid_env[axis].axis_index
program_id_p.def_impl(_program_id_impl)
mlir.register_lowering(program_id_p, functools.partial(xla.apply_primitive,
program_id_p))
def _program_id_abstract_eval(**_):
return jax_core.ShapedArray((), jnp.int32)
program_id_p.def_abstract_eval(_program_id_abstract_eval)
class AtomicOpType(enum.Enum):
XCHG = "xchg"
ADD = "add"
MAX = "max"
MIN = "min"
AND = "and"
OR = "or"
XOR = "xor"
atomic_rmw_p = jax_core.Primitive("atomic_rmw")
def _atomic_rmw_discharge_rule(in_avals, out_avals, ref, val, *args, args_tree,
masked, atomic_type: AtomicOpType):
if masked: raise NotImplementedError
ref_aval, val_aval, *in_avals = in_avals
idx_aval, *_ = tree_util.tree_unflatten(args_tree, in_avals)
idx, *_ = tree_util.tree_unflatten(args_tree, args)
if atomic_type == AtomicOpType.ADD:
monoid = lambda x, y: x + y
elif atomic_type == AtomicOpType.MAX:
monoid = jnp.maximum
elif atomic_type == AtomicOpType.MIN:
monoid = jnp.minimum
else:
raise NotImplementedError(atomic_type)
if all(isinstance(s, Slice) or s.shape == () for s in idx.indices):
indices = idx.indices
scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
val_indexer = tuple(None if scalar else slice(None) for scalar in scalar_dims)
val = val[val_indexer]
val = monoid(val, out_ones)
x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
out = out_ones[out_indexer]
elif all(not isinstance(s, Slice) for s in idx.indices):
out = ref[idx.indices]
x_new = ref.at[idx.indices].set(monoid(out, val))
else:
raise NotImplementedError
return (x_new,) + (None,) * (len(in_avals) + 1), out
state_discharge.register_discharge_rule(atomic_rmw_p)(_atomic_rmw_discharge_rule)
def _atomic_abstract_eval(ref_aval, val_aval, *all_avals,
args_tree, atomic_type: AtomicOpType,
**_: Any):
if ref_aval.dtype == jnp.dtype("float16") and atomic_type != AtomicOpType.ADD:
raise ValueError(f"`atomic_{atomic_type.value}` does not support f16.")
if ref_aval.dtype in {jnp.dtype("bool"), jnp.dtype("int8"),
jnp.dtype("int16"), jnp.bfloat16}:
raise ValueError(f"`atomic_{atomic_type.value}` does not support {ref_aval.dtype}.")
return _swap_abstract_eval(ref_aval, val_aval, *all_avals,
args_tree=args_tree)
atomic_rmw_p.def_effectful_abstract_eval(_atomic_abstract_eval)
def atomic_rmw(x_ref, idx, val, *, mask: Any | None = None,
atomic_type: AtomicOpType):
idx = NDIndexer.from_indices_shape(idx, x_ref.shape)
args = (idx,)
if mask is not None:
args = (*args, mask)
flat_args, args_tree = tree_util.tree_flatten(args)
return atomic_rmw_p.bind(x_ref, val, *flat_args, args_tree=args_tree,
atomic_type=atomic_type, masked=mask is not None)
atomic_xchg = functools.partial(atomic_rmw, atomic_type=AtomicOpType.XCHG)
atomic_add = functools.partial(atomic_rmw, atomic_type=AtomicOpType.ADD)
atomic_max = functools.partial(atomic_rmw, atomic_type=AtomicOpType.MAX)
atomic_min = functools.partial(atomic_rmw, atomic_type=AtomicOpType.MIN)
atomic_and = functools.partial(atomic_rmw, atomic_type=AtomicOpType.AND)
atomic_or = functools.partial(atomic_rmw, atomic_type=AtomicOpType.OR)
atomic_xor = functools.partial(atomic_rmw, atomic_type=AtomicOpType.XOR)
atomic_cas_p = jax_core.Primitive("atomic_cas")
def _atomic_cas_abstract_eval(ref_aval, cmp_aval, val_aval):
if cmp_aval.dtype != val_aval.dtype:
raise ValueError("Dtypes in cmp/val need to match")
if ref_aval.shape != ():
raise ValueError("Ref must be scalar.")
if cmp_aval.shape != ():
raise ValueError("Cmp must be scalar.")
if val_aval.shape != ():
raise ValueError("Val must be scalar.")
if cmp_aval.shape != val_aval.shape:
raise ValueError("Dtypes in cmp/val need to match")
return jax_core.ShapedArray(val_aval.shape, val_aval.dtype), {state.WriteEffect(0)}
atomic_cas_p.def_effectful_abstract_eval(_atomic_cas_abstract_eval)
def atomic_cas(ref, cmp, val):
return atomic_cas_p.bind(ref, cmp, val)
@state_discharge.register_discharge_rule(atomic_cas_p)
def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val):
del in_avals, out_avals
new_val = jnp.where(ref == cmp, val, ref)
return (new_val, None, None), ref
max_contiguous_p = jax_core.Primitive("max_contiguous")
max_contiguous_p.def_impl(lambda x, **_: x)
mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x])
def max_contiguous(x, values):
if not isinstance(values, list):
values = [values]
return max_contiguous_p.bind(x, values=values)
def _max_contiguous_abstract_eval(aval, **_):
return aval
max_contiguous_p.def_abstract_eval(_max_contiguous_abstract_eval)
multiple_of_p = jax_core.Primitive("multiple_of")
multiple_of_p.def_impl(lambda x, **_: x)
mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x])
def multiple_of(x, values):
if not isinstance(values, list):
values = [values]
return multiple_of_p.bind(x, values=values)
def _multiple_of_abstract_eval(aval, **_):
return aval
multiple_of_p.def_abstract_eval(_multiple_of_abstract_eval)
load_p = jax_core.Primitive('masked_load')
def _load_abstract_eval(ref_aval, *all_avals, args_tree,
**params: Any):
idx_aval, *_ = tree_util.tree_unflatten(args_tree, all_avals)
return (jax_core.ShapedArray(idx_aval.get_indexer_shape(), ref_aval.dtype),
{state.ReadEffect(0)})
load_p.def_effectful_abstract_eval(_load_abstract_eval)
def _pp_dslice(dim: int, slice: Slice, context):
size = pp.text(str(slice.size))
if isinstance(slice.start, int):
if slice.start == 0:
start = pp.text("")
else:
start = pp.text(str(slice.start))
if slice.size == dim:
end = pp.text("")
else:
end = pp.text(str(slice.start + slice.size))
else:
start = pp.text(jax_core.pp_var(slice.start, context))
end = pp.concat([start, pp.text("+"), size])
return pp.concat([start, pp.text(":"), end])
def _pp_idx(ref_aval, idx: NDIndexer, context):
docs = [
_pp_dslice(d, s, context) if isinstance(s, Slice)
else pp.text(jax_core.pp_var(s, context))
for s, d in zip(idx.indices, ref_aval.shape)]
if not docs:
return pp.text("")
doc = [docs[0]]
for d in docs[1:]:
doc.append(pp.text(","))
doc.append(d)
return pp.concat(doc)
def _load_pp_rule(eqn, context, settings):
  # Pretty prints `a = load x i` as `a <- x[i]`
y, = eqn.outvars
x, *args = eqn.invars
idx, *masked_other = tree_util.tree_unflatten(eqn.params["args_tree"], args)
idx = _pp_idx(eqn.invars[0].aval, idx, context)
lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([
pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']')
]))])
jax_core.pp_eqn_rules[load_p] = _load_pp_rule
def _load_jvp(primals, tangents, *, args_tree, masked, **params: Any):
ref_primal, *rest_primals = primals
ref_tangent, *rest_tangents = tangents
idx_primal, *masked_other_primals = tree_util.tree_unflatten(args_tree, rest_primals)
flat_idx_primals = tree_util.tree_leaves(idx_primal)
_, *masked_other_tangents = tree_util.tree_unflatten(args_tree, rest_tangents)
tangent_args = flat_idx_primals
if masked:
tangent_args = [*tangent_args, masked_other_primals[0]]
if len(masked_other_tangents) == 2:
_, other_tangent = masked_other_tangents
other_tangent = ad_util.instantiate(other_tangent)
tangent_args = [*tangent_args, other_tangent]
return (
load_p.bind(ref_primal, *rest_primals, args_tree=args_tree, masked=masked, **params),
load_p.bind(ref_tangent, *tangent_args, args_tree=args_tree,
masked=masked, **params))
ad.primitive_jvps[load_p] = _load_jvp
def _load_discharge_rule(in_avals, out_avals, ref, *args, args_tree,
masked, eviction_policy, cache_modifier, is_volatile):
idx, *masked_other = tree_util.tree_unflatten(args_tree, args)
if all(isinstance(s, Slice) or not s.shape for s in idx.indices):
indices = idx.indices
scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
out = out_ones[out_indexer]
elif all(not isinstance(s, Slice) for s in idx.indices):
out = ref[idx.indices]
else:
raise NotImplementedError
if masked and len(masked_other) == 2:
mask, other = masked_other
out = jnp.where(mask, out, other)
return (None,) * len(in_avals), out
state_discharge.register_discharge_rule(load_p)(_load_discharge_rule)
swap_p = jax_core.Primitive('masked_swap')
def _swap_abstract_eval(ref_aval, val_aval, *all_avals, args_tree,
**_: Any):
idx_aval, *_ = tree_util.tree_unflatten(args_tree, all_avals)
expected_output_shape = idx_aval.get_indexer_shape()
if expected_output_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx_aval}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `swap`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
return (jax_core.ShapedArray(expected_output_shape, ref_aval.dtype),
{state.WriteEffect(0)})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
def _swap_pp_rule(eqn, context, settings):
# Pretty prints `a = swap x v i` as `a, x[i] <- x[i], v`
# or:
# Pretty prints `_ = swap x v i` as `x[i] <- v`
y, = eqn.outvars
x, val, *args = eqn.invars
idx, *masked_other = tree_util.tree_unflatten(eqn.params["args_tree"], args)
idx = _pp_idx(eqn.invars[0].aval, idx, context)
x_i = pp.concat([pp.text(jax_core.pp_var(x, context)),
pp.text('['), idx, pp.text(']')])
if isinstance(y, jax_core.DropVar):
return pp.concat([state_primitives.pp_ref(
x_i), pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([y, pp.text(', '), state_primitives.pp_ref(x_i),
pp.text(' <- '), state_primitives.pp_ref(x_i),
pp.text(', '), pp.text(jax_core.pp_var(val, context))])
jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule
def _swap_jvp(primals, tangents, *, args_tree, masked, **params: Any):
ref_primal, val_primal, *rest_primals = primals
ref_tangent, val_tangent, *rest_tangents = tangents
val_tangent = ad_util.instantiate(val_tangent)
idx_primal, *masked_other_primals = tree_util.tree_unflatten(args_tree, rest_primals)
flat_idx_primals = tree_util.tree_leaves(idx_primal)
_, *masked_other_tangents = tree_util.tree_unflatten(args_tree, rest_tangents)
tangent_args = flat_idx_primals
if masked:
tangent_args = [*tangent_args, masked_other_primals[0]]
if len(masked_other_tangents) == 2:
_, other_tangent = masked_other_tangents
other_tangent = ad_util.instantiate(other_tangent)
tangent_args = [*tangent_args, other_tangent]
return (
swap_p.bind(ref_primal, val_primal, *rest_primals, args_tree=args_tree, masked=masked, **params),
swap_p.bind(ref_tangent, val_tangent, *tangent_args, args_tree=args_tree,
masked=masked, **params))
ad.primitive_jvps[swap_p] = _swap_jvp
def _swap_discharge_rule(in_avals, out_avals, ref, val, *args, args_tree,
masked, eviction_policy):
idx, *_ = tree_util.tree_unflatten(args_tree, args)
if all(isinstance(s, Slice) or s.shape == () for s in idx.indices):
indices = idx.indices
scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
val_indexer = tuple(None if scalar else slice(None) for scalar in scalar_dims)
val = val[val_indexer]
x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
out = out_ones[out_indexer]
elif all(not isinstance(s, Slice) for s in idx.indices):
out = ref[idx.indices]
x_new = ref.at[idx.indices].set(val)
else:
raise NotImplementedError
return (x_new,) + (None,) * (len(in_avals) - 1), out
state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule)
def load(x_ref, idx, *, mask=None, other=None, cache_modifier="",
eviction_policy="", volatile=False):
idx = NDIndexer.from_indices_shape(idx, x_ref.shape)
args = (idx,)
if mask is not None:
args = (*args, mask)
if other is not None:
assert mask is not None
args = (*args, other)
flat_args, args_tree = tree_util.tree_flatten(args)
return load_p.bind(x_ref, *flat_args, masked=mask is not None, cache_modifier=cache_modifier,
eviction_policy=eviction_policy, is_volatile=volatile,
args_tree=args_tree)
def swap(x_ref, idx, val, *, mask=None, eviction_policy="") -> Any:
idx = NDIndexer.from_indices_shape(idx, x_ref.shape)
args = (idx,)
if mask is not None:
args = (*args, mask)
flat_args, args_tree = tree_util.tree_flatten(args)
return swap_p.bind(x_ref, val, *flat_args, masked=mask is not None,
eviction_policy=eviction_policy, args_tree=args_tree)
def store(x_ref, idx, val, *, mask=None, eviction_policy="") -> None:
_ = swap(x_ref, idx, val, mask=mask, eviction_policy=eviction_policy)
def dot(a, b, trans_a: bool = False, trans_b: bool = False,
allow_tf32: bool | None = None, precision=None):
lhs_contract_dim = 0 if trans_a else 1
rhs_contract_dim = 0 if not trans_b else 1
if allow_tf32 is not None:
if precision is not None:
raise ValueError("Only one of allow_tf32 and precision can be specified")
precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST
return jax.lax.dot_general(
a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
precision=precision,
preferred_element_type=None).astype(jnp.float32)
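As a quick illustration of the contraction-dimension mapping above, a minimal sketch (assuming this module's `dot`, re-exported as `pl.dot`, is in scope and called on plain arrays outside a kernel):

import jax.numpy as jnp

a = jnp.ones((8, 16), jnp.float16)
b = jnp.ones((8, 32), jnp.float16)
# trans_a=True contracts dimension 0 of `a` with dimension 0 of `b`,
# i.e. this computes a.T @ b and returns the result in float32.
out = dot(a, b, trans_a=True)
assert out.shape == (16, 32) and out.dtype == jnp.float32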

View File

@ -0,0 +1,46 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Package for Triton-specific Pallas extensions
load(
"//jaxlib:jax.bzl",
"py_deps",
"py_library_providing_imports_info",
"pytype_strict_library",
)
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
],
)
py_library_providing_imports_info(
name = "triton",
srcs = ["__init__.py"],
lib_rule = pytype_strict_library,
deps = [
":lowering",
"//jax/_src/lib",
],
)
py_library(
name = "lowering",
srcs = ["lowering.py"],
deps = [
"//jax",
] + py_deps("jax_triton"),
)

View File

@ -0,0 +1,22 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains Triton-specific pallas modules."""
from jax._src.pallas.triton import lowering
from jax._src.lib import gpu_triton as triton_kernel_call_lib
get_compute_capability = triton_kernel_call_lib.get_compute_capability
del lowering

File diff suppressed because it is too large

jax/_src/pallas/utils.py
View File

@ -0,0 +1,49 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pallas utility functions."""
import math
import numpy as np
from typing import Tuple
from jax import lax
def when(condition):
def _wrapped(f):
if isinstance(condition, bool):
if condition:
f()
else:
lax.cond(condition, f, lambda: None)
return _wrapped
def cdiv(a: int, b: int) -> int:
return (a + b - 1) // b
def strides_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
size = np.prod(shape)
strides = []
for s in shape:
size = size // s
strides.append(int(size))
return tuple(strides)
def next_power_of_2(x: int) -> int:
if x == 0:
return 1
return int(2 ** math.ceil(math.log2(x)))
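A few worked values for the helpers above (plain Python, no tracing involved):

assert cdiv(10, 3) == 4                              # ceil(10 / 3)
assert strides_from_shape((2, 3, 4)) == (12, 4, 1)   # row-major strides
assert next_power_of_2(5) == 8 and next_power_of_2(0) == 1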

View File

@ -15,6 +15,7 @@
"""JAX bindings for Mosaic."""
# mypy: ignore-errors
from __future__ import annotations
import base64
import collections.abc
@ -29,6 +30,7 @@ from jax import core
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.config import config
from jax._src.lib import xla_client
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import mhlo
from jaxlib.mlir.dialects import stablehlo
@ -38,7 +40,7 @@ import numpy as np
# TODO(sharadmv): remove when minimum jaxlib version is bumped to >= 0.4.14.
if tpu_mosaic is None:
raise ValueError("Cannot use Mosaic without a jaxlib >= 0.4.14.")
raise ImportError("Cannot use Mosaic without a jaxlib >= 0.4.14.")
tpu = tpu_mosaic.tpu
apply_vector_layout = tpu_mosaic.apply_vector_layout
infer_memref_layout = tpu_mosaic.infer_memref_layout
@ -244,7 +246,7 @@ def as_tpu_kernel(
module: ir.Module,
out_type: Any,
*,
backend: str = "tpu",
backend: str | xla_client.Client = "tpu",
) -> Callable[..., Any]:
"""Turns an MLIR Mosaic kernel into a JAX-compatible function."""
# We use jax.jit to make sure we hit the fast compilation cache.

View File

@ -0,0 +1,52 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for pallas, a JAX extension for custom kernels."""
from jax._src import pallas
from jax._src.pallas.core import BlockSpec
from jax._src.pallas.indexing import ds
from jax._src.pallas.indexing import dslice
from jax._src.pallas.indexing import broadcast_to
from jax._src.pallas.pallas_call import pallas_call
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.primitives import atomic_add
from jax._src.pallas.primitives import atomic_and
from jax._src.pallas.primitives import atomic_cas
from jax._src.pallas.primitives import atomic_max
from jax._src.pallas.primitives import atomic_min
from jax._src.pallas.primitives import atomic_or
from jax._src.pallas.primitives import atomic_xchg
from jax._src.pallas.primitives import atomic_xor
from jax._src.pallas.primitives import dot
from jax._src.pallas.primitives import load
from jax._src.pallas.primitives import max_contiguous
from jax._src.pallas.primitives import multiple_of
from jax._src.pallas.primitives import program_id
from jax._src.pallas.primitives import store
from jax._src.pallas.primitives import swap
from jax._src.pallas.utils import cdiv
from jax._src.pallas.utils import next_power_of_2
from jax._src.pallas.utils import strides_from_shape
from jax._src.pallas.utils import when
try:
from jax.experimental.pallas import gpu # pytype: disable=import-error
except (ImportError, ModuleNotFoundError):
pass
try:
from jax.experimental.pallas import tpu # pytype: disable=import-error
except (ImportError, ModuleNotFoundError):
pass
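A minimal sketch of how the re-exported API fits together, using an illustrative element-wise kernel in interpreter mode so it runs without a GPU or TPU:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Refs are read and written with array-style indexing inside the kernel.
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # emulate the kernel instead of compiling it
)(x)
assert jnp.allclose(out, x + 1)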

View File

@ -0,0 +1,22 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains Triton specific Pallas functions."""
try:
from jax._src.pallas import triton
get_compute_capability = triton.get_compute_capability
del triton
except ImportError as e:
raise ImportError("Cannot import Pallas Triton backend. "
"Make sure you've installed jax-triton.") from e

View File

@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,364 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing fused attention forward and backward pass."""
import functools
from typing import Any, Optional
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
def mha_forward_kernel(
q_ref, k_ref, v_ref, # Input arrays
o_ref, # Output
*residual_refs, # Residual outputs
sm_scale: float, causal: bool,
block_q: int, block_d: int, block_k: int):
seq_len = q_ref.shape[0]
start_q = pl.program_id(0)
  # m_i and l_i (see the FlashAttention paper) are updated during the k,v loop.
m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf')
l_i = jnp.zeros(block_q, dtype=jnp.float32)
# acc is the buffer where we accumulate the output on sram.
acc = jnp.zeros((block_q, block_d), dtype=jnp.float32)
# Load q: it will stay in L1 throughout. Indices form a matrix because we
# read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
# q tile has shape [block_q, block_d], block_d == head_dim.
q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)))
  # In FlashAttention algorithm 1 there are 2 loops: a slow one over tiles of kv
  # (of size Bc == block_k here) and a fast one over blocks of q (of size
  # Br == block_q here). Here we only loop over blocks of kv to process the
  # entire seq_len; the loop over blocks of q is carried out by the grid. The
  # rescaling identities used in `body` are spelled out right after this kernel.
def body(start_k, carry):
acc, m_prev, l_prev = carry
k = pl.load(k_ref, (pl.dslice(start_k * block_k, block_k), slice(None)))
qk = jnp.zeros([block_q, block_k], dtype=jnp.float32)
qk += pl.dot(q, k.T) # [block_q, block_k]
if sm_scale != 1.:
qk *= sm_scale # [block_q, block_k]
if causal:
span_q = start_q * block_q + jnp.arange(block_q)
span_k = start_k * block_k + jnp.arange(block_k)
qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
# Bring closer to XLA:GPU numerics.
qk = qk.astype(q_ref.dtype)
qk = qk.astype(jnp.float32)
m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev)
l_prev *= jnp.exp(m_prev - m_curr)
p = jnp.exp(qk - m_curr[:, None])
l_curr = jnp.sum(p, axis=1) + l_prev
l_rcp = 1. / l_curr
p = p * l_rcp[:, None]
acc *= (l_prev * l_rcp)[:, None]
p = p.astype(jnp.float16)
v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), pl.dslice(block_d)))
acc = acc + pl.dot(p.astype(v.dtype), v)
return acc, m_curr, l_curr
if causal:
upper_bound = lax.div(block_q * start_q, block_k) + 1
else:
upper_bound = pl.cdiv(seq_len, block_k) # type: ignore
acc, m_i, l_i = lax.fori_loop(0, upper_bound, body,
(acc, m_i, l_i))
if residual_refs:
l_ref, m_ref = residual_refs
pl.store(l_ref, (pl.ds(start_q * block_q, block_q),), l_i)
pl.store(m_ref, (pl.ds(start_q * block_q, block_q),), m_i)
# Write output to dram.
acc = acc.astype(o_ref.dtype)
pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc)
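Written out, the rescaling performed in `body` is the standard online-softmax update (with $s$ the optional `sm_scale`, $q$ this program's query tile, and $k_j, v_j$ the rows of the current kv block):

$$m_{\mathrm{curr}} = \max\!\big(m_{\mathrm{prev}},\ \max_j s\,q k_j^\top\big),\qquad
l_{\mathrm{curr}} = e^{m_{\mathrm{prev}}-m_{\mathrm{curr}}}\,l_{\mathrm{prev}} + \sum_j e^{s\,q k_j^\top - m_{\mathrm{curr}}},$$

$$\mathrm{acc} \;\leftarrow\; \frac{e^{m_{\mathrm{prev}}-m_{\mathrm{curr}}}\,l_{\mathrm{prev}}}{l_{\mathrm{curr}}}\,\mathrm{acc}
\;+\; \frac{1}{l_{\mathrm{curr}}}\sum_j e^{s\,q k_j^\top - m_{\mathrm{curr}}}\,v_j,$$

so that after the final kv block `acc` holds $\mathrm{softmax}(s\,qK^\top)\,V$ for this block of queries.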
@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
@functools.partial(jax.jit, static_argnames=["sm_scale", "causal", "block_q", "block_k",
"backward_pass_impl",
"num_warps", "num_stages", "grid",
"interpret", "debug"])
def mha(q, k, v,
sm_scale: float = 1.0,
causal: bool = False,
block_q: int = 128,
block_k: int = 128,
backward_pass_impl: str = "triton",
num_warps: Optional[int] = None,
num_stages: int = 2,
grid=None,
interpret: bool = False,
debug: bool = False):
del backward_pass_impl
batch_size, seq_len, num_heads, head_dim = q.shape
block_q = min(block_q, seq_len)
block_k = min(block_k, seq_len)
# Heuristics.
grid_ = grid
if grid_ is None:
grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)
num_warps_ = num_warps
if num_warps_ is None:
num_warps_ = 4 if head_dim <= 64 else 8
kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
block_q=block_q, block_k=block_k,
block_d=head_dim,
causal=causal)
out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
return pl.pallas_call(
kernel,
grid=grid_,
in_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
out_specs=pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
num_warps=num_warps_,
num_stages=num_stages,
out_shape=out_shape,
debug=debug,
interpret=interpret,
name="mha_forward")(q, k, v)
def _mha_forward(q, k, v, sm_scale: float, causal: bool, block_q: int,
block_k: int, backward_pass_impl: str, num_warps: Optional[int],
num_stages: int, grid: Any, interpret: bool, debug: bool):
del backward_pass_impl
batch_size, seq_len, num_heads, head_dim = q.shape
block_q = min(block_q, seq_len)
block_k = min(block_k, seq_len)
# Heuristics.
grid_ = grid
if grid_ is None:
grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)
num_warps_ = num_warps
if num_warps_ is None:
num_warps_ = 4 if head_dim <= 64 else 8
kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
causal=causal, block_q=block_q, block_k=block_k,
block_d=head_dim)
out_shape = [
jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out
jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # l
dtype=jnp.float32),
jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # m
dtype=jnp.float32)
]
out, l, m = pl.pallas_call(
kernel,
grid=grid_,
in_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
out_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
],
num_warps=num_warps_,
num_stages=num_stages,
out_shape=out_shape,
debug=debug,
interpret=interpret,
name="mha_forward")(q, k, v)
return out, (q, k, v, out, l, m)
def _preprocess_backward_kernel(out_ref, dout_ref, l_ref,
new_dout_ref, delta_ref, *,
block_q: int):
pid_m = pl.program_id(0)
off_m = pl.ds(pid_m * block_q, block_q)
# load
o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32)
do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32)
denom = pl.load(l_ref, (off_m,)).astype(jnp.float32)
# compute
do = do / denom[:, None]
delta = jnp.sum(o * do, axis=1)
# write-back
pl.store(new_dout_ref, (off_m, slice(None)),
do.astype(new_dout_ref.dtype))
pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype))
def _preprocess_backward(out, do, l, block_q: int,
debug: bool, interpret: bool):
batch_size, seq_len, num_heads, head_dim = out.shape
out_shape = [
jax.ShapeDtypeStruct(do.shape, do.dtype),
jax.ShapeDtypeStruct(l.shape, l.dtype),
]
do_scaled, delta = pl.pallas_call(
functools.partial(_preprocess_backward_kernel, block_q=block_q),
grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
in_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
],
out_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
],
num_warps=4,
num_stages=3,
out_shape=out_shape,
debug=debug,
interpret=interpret,
name="mha_preprocess_backward")(out, do, l)
return do_scaled, delta
def mha_backward_kernel(
# Inputs
q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
l_ref, m_ref, delta_ref, _,
# Outputs
dq_ref, dk_ref, dv_ref,
*, sm_scale: float, causal: bool,
block_q: int, block_d: int, block_k: int
):
del out_ref, l_ref # Not needed
seq_len = q_ref.shape[0]
def outer_loop(start_k, _):
dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
span_k = start_k * block_k + jnp.arange(block_k)
def inner_loop(start_q, carry):
dv, dk = carry
q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
qk = pl.dot(q, k.T)
qk = qk.astype(q_ref.dtype)
qk = qk.astype(jnp.float32)
if sm_scale != 1.0:
qk *= sm_scale
if causal:
span_q = start_q * block_q + jnp.arange(block_q)
qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
p = jnp.exp(qk - m[:, None])
do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
dv = dv + pl.dot(p.astype(do.dtype).T, do)
di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale
dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), eviction_policy="evict_last")
dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), dq, eviction_policy="evict_last")
return dv, dk
if causal:
lower_bound = lax.div(start_k * block_k, block_q)
else:
lower_bound = 0
dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop,
(dv, dk))
pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dv.astype(dv_ref.dtype))
pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dk.astype(dk_ref.dtype))
lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None)
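In the notation of the FlashAttention backward pass, with $P = \exp(s\,QK^\top - m)$ per block (recall that $dO$ was already divided by the softmax denominator in `_preprocess_backward`, and that $\Delta = \sum_c O \circ dO$ is `delta`), the inner loop accumulates

$$dV \mathrel{+}= P^\top dO,\qquad dP = dO\,V^\top - \Delta,\qquad dS = s\,(P\circ dP),\qquad
dK \mathrel{+}= dS^\top Q,\qquad dQ \mathrel{+}= dS\,K.$$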
def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
backward_pass_impl: str, num_warps: Optional[int],
num_stages: int, grid: Any, interpret: bool,
debug: bool, res, do):
del num_warps, num_stages, grid
q, k, v, out, l, m = res
batch_size, seq_len, num_heads, head_dim = q.shape
block_q = min(block_q, seq_len)
block_k = min(block_k, seq_len)
do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)
if backward_pass_impl == "xla":
return jax.vjp(functools.partial(mha_reference, sm_scale=sm_scale,
causal=causal), q, k, v)[1](do)
elif backward_pass_impl == "triton":
# We accumulate into dq so we need to initialize it to zeros.
dq = jnp.zeros(q.shape, jnp.float32)
out_shapes = [
jax.ShapeDtypeStruct(dq.shape, dq.dtype),
jax.ShapeDtypeStruct(k.shape, k.dtype),
jax.ShapeDtypeStruct(v.shape, v.dtype),
]
grid = (batch_size, num_heads)
# TODO(sharadmv): figure out why num_warps=8 doesn't work!
num_warps = 4
dq, dk, dv = pl.pallas_call(
functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim,
block_k=block_k, sm_scale=sm_scale, causal=causal),
grid=grid,
out_shape=out_shapes,
in_specs=[
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
out_specs=[
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
name="mha_backward",
debug=debug,
interpret=interpret,
num_warps=num_warps,
num_stages=1,
input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
else:
raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
return dq.astype(q.dtype), dk, dv
mha.defvjp(_mha_forward, _mha_backward)
@functools.partial(jax.jit, static_argnames=['sm_scale', 'causal'])
def mha_reference(q, k, v, sm_scale=1.0, causal: bool = False):
q_seq_len = q.shape[1]
kv_seq_len = k.shape[1]
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32)
if causal:
mask = jnp.tril(jnp.ones((1, 1, q_seq_len, kv_seq_len), dtype=bool))
mask = jnp.broadcast_to(mask, logits.shape)
logits = jnp.where(mask, logits, float('-inf'))
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
return jnp.einsum('bhqk,bkhc->bqhc', weights, v)
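A sketch of how the fused kernel might be smoke-tested against `mha_reference`; the shapes, tolerances, and use of interpret mode (which runs the kernel through the Pallas interpreter, so no GPU is required) are illustrative assumptions:

import jax
import jax.numpy as jnp

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
batch, seq_len, heads, head_dim = 2, 256, 4, 64
q = jax.random.normal(k1, (batch, seq_len, heads, head_dim), jnp.float16)
k = jax.random.normal(k2, (batch, seq_len, heads, head_dim), jnp.float16)
v = jax.random.normal(k3, (batch, seq_len, heads, head_dim), jnp.float16)

out = mha(q, k, v, causal=True, interpret=True)
ref = mha_reference(q, k, v, causal=True)
assert jnp.allclose(out, ref, atol=5e-2, rtol=5e-2)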

View File

@ -0,0 +1,281 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing fused layer norm forward and backward pass."""
import functools
from typing import Optional
import jax
from jax import lax
import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop
from jax.experimental import pallas as pl
def layer_norm_forward_kernel(
x_ref, weight_ref, bias_ref, # Input arrays
o_ref, mean_ref=None, rstd_ref=None, # Output arrays
*, eps: float, block_size: int):
n_col = x_ref.shape[0]
def mean_body(i, acc_ref):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
acc_ref[:] += a
mean = for_loop(pl.cdiv(n_col, block_size), mean_body,
jnp.zeros(block_size)).sum() / n_col
def var_body(i, acc_ref):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
a = jnp.where(mask, a - mean, 0.)
acc_ref[:] += a * a
var = for_loop(pl.cdiv(n_col, block_size), var_body,
jnp.zeros(block_size)).sum() / n_col
rstd = 1 / jnp.sqrt(var + eps)
if mean_ref is not None:
mean_ref[...] = mean.astype(mean_ref.dtype)
if rstd_ref is not None:
rstd_ref[...] = rstd.astype(rstd_ref.dtype)
def body(i, _):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
weight = pl.load(weight_ref, (col_idx,), mask=mask)
bias = pl.load(bias_ref, (col_idx,), mask=mask)
x = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_first").astype(jnp.float32)
out = (x - mean) * rstd * weight + bias
pl.store(o_ref, (col_idx,), out.astype(o_ref.dtype), mask=mask)
for_loop(pl.cdiv(n_col, block_size), body, ())
def layer_norm_forward(
x, weight, bias,
num_warps: Optional[int] = None,
num_stages: Optional[int] = 3,
eps: float = 1e-5,
backward_pass_impl: str = 'triton',
interpret: bool = False):
del num_stages
del backward_pass_impl
n = x.shape[-1]
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
kernel = functools.partial(layer_norm_forward_kernel, eps=eps,
block_size=block_size)
out_shape = [
jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype),
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype),
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
]
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=(), out_shape=out_shape, debug=False,
interpret=interpret, name='ln_forward')
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
out, mean, rstd = method(x, weight, bias)
return out, (x, weight, bias, mean, rstd)
def layer_norm_backward_kernel_dx(
# Inputs
x_ref, weight_ref, bias_ref, do_ref,
mean_ref, rstd_ref,
# Outputs
dx_ref,
*, eps: float, block_size: int):
n_col = x_ref.shape[0]
def mean_body(i, acc_ref):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
dout = pl.load(do_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
a_hat = (a - mean_ref[...]) * rstd_ref[...]
wdout = weight * dout
mean1_acc_ref, mean2_acc_ref = acc_ref
mean1_acc_ref[:] += a_hat * wdout
mean2_acc_ref[:] += wdout
mean = for_loop(pl.cdiv(n_col, block_size), mean_body,
(jnp.zeros(block_size), jnp.zeros(block_size)))
mean1, mean2 = mean
mean1 = mean1.sum() / n_col
mean2 = mean2.sum() / n_col
def dx_body(i, acc_ref):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
dout = pl.load(do_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
a_hat = (a - mean_ref[...]) * rstd_ref[...]
wdout = weight * dout
da = (wdout - (a_hat * mean1 + mean2)) * rstd_ref[...]
pl.store(dx_ref, (col_idx,), da.astype(dx_ref.dtype), mask=mask)
for_loop(pl.cdiv(n_col, block_size), dx_body, ())
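For reference, writing $\hat{x}_i = (x_i - \mu)\,r$ with $r$ the stored rstd, $\delta_i = w_i\,\mathrm{dout}_i$, $c_1 = \tfrac{1}{N}\sum_j \hat{x}_j\,\delta_j$ (the first accumulator above) and $c_2 = \tfrac{1}{N}\sum_j \delta_j$ (the second), the value stored into `dx_ref` is the usual layer-norm input gradient

$$\frac{\partial L}{\partial x_i} = r\,\big(\delta_i - \hat{x}_i\,c_1 - c_2\big).$$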
def layer_norm_backward_kernel_dw_db(
# Inputs
x_ref, weight_ref, bias_ref, do_ref,
mean_ref, rstd_ref,
# Outputs
dw_ref, db_ref,
*, eps: float, block_m: int, block_n: int):
m, n_col = x_ref.shape
j = pl.program_id(0)
col_idx = j * block_n + jnp.arange(block_n)
col_mask = col_idx < n_col
def body(i, acc_ref):
row_idx = i * block_m + jnp.arange(block_m)
row_mask = row_idx < m
mask = row_mask[:, None] & col_mask[None, :]
a = pl.load(
x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
).astype(jnp.float32)
dout = pl.load(
do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
).astype(jnp.float32)
mean = pl.load(mean_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw_acc_ref, db_acc_ref = acc_ref
dw_acc_ref[:] += (dout * a_hat).sum(axis=0)
db_acc_ref[:] += dout.sum(axis=0)
dw_acc, db_acc = for_loop(pl.cdiv(m, block_m), body, (jnp.zeros(block_n), jnp.zeros(block_n)))
pl.store(dw_ref, (col_idx,), dw_acc.astype(dw_ref.dtype), mask=col_mask)
pl.store(db_ref, (col_idx,), db_acc.astype(db_ref.dtype), mask=col_mask)
def layer_norm_backward(
num_warps: Optional[int],
num_stages: Optional[int],
eps: float,
backward_pass_impl: str,
interpret: bool,
res, do):
del num_stages
x, weight, bias, mean, rstd = res
if backward_pass_impl == 'xla':
return jax.vjp(layer_norm_reference, x, weight, bias)[1](do)
*shape_prefix, n = x.shape
reshaped_x = x.reshape((-1, n))
reshaped_mean = mean.reshape((-1,))
reshaped_rstd = rstd.reshape((-1,))
reshaped_do = do.reshape((-1, n))
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
# layer_norm_backward_kernel_dx parallel over batch dims
kernel = functools.partial(layer_norm_backward_kernel_dx, eps=eps,
block_size=block_size)
out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=(), out_shape=out_shape_dx, debug=False,
interpret=interpret, name='ln_backward_dx')
method = jax.vmap(method, in_axes=(0, None, None, 0, 0, 0))
dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
dx = dx.reshape((*shape_prefix, n))
# layer_norm_backward_kernel_dw_db reduce over batch dims
# Triton heuristics
if n > 10240:
block_n = 128
block_m = 32
num_warps = 4
else:
# maximize occupancy for small N
block_n = 16
block_m = 16
num_warps = 8
kernel = functools.partial(layer_norm_backward_kernel_dw_db, eps=eps,
block_m=block_m, block_n=block_n)
out_shape_dwbias = [
jax.ShapeDtypeStruct(shape=weight.shape, dtype=weight.dtype),
jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
]
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=grid_, out_shape=out_shape_dwbias, debug=False,
interpret=interpret, name='ln_backward_dw_db')
dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
return dx, dw, dbias
@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7])
@functools.partial(jax.jit, static_argnames=["num_warps", "num_stages",
"num_stages", "eps",
"backward_pass_impl",
"interpret"])
def layer_norm(
x, weight, bias,
num_warps: Optional[int] = None,
num_stages: Optional[int] = 3,
eps: float = 1e-5,
backward_pass_impl: str = 'triton',
interpret: bool = False):
n = x.shape[-1]
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
kernel = functools.partial(layer_norm_forward_kernel, eps=eps,
block_size=block_size)
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
grid=(), out_shape=out_shape, debug=False,
interpret=interpret)
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
return method(x, weight, bias)
layer_norm.defvjp(layer_norm_forward, layer_norm_backward)
@functools.partial(jax.jit, static_argnames=["eps"])
@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0)
def layer_norm_reference(x, weight, bias, *, eps: float = 1e-5):
mean = jnp.mean(x, axis=1)
mean2 = jnp.mean(jnp.square(x), axis=1)
var = jnp.maximum(0., mean2 - jnp.square(mean))
y = x - mean[:, None]
mul = lax.rsqrt(var + eps)
return y * mul[:, None] * weight[None] + bias[None]

View File

@ -0,0 +1,259 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing rms forward and backward pass."""
import functools
from typing import Optional
import jax
from jax import lax
import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop
from jax.experimental import pallas as pl
def rms_norm_forward_kernel(
x_ref, weight_ref, bias_ref, # Input arrays
o_ref, rstd_ref=None, # Output arrays
*, eps: float, block_size: int):
n_col = x_ref.shape[0]
def var_body(i, acc_ref):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
a = jnp.where(mask, a, 0.)
acc_ref[:] += a * a
var = for_loop(pl.cdiv(n_col, block_size), var_body,
jnp.zeros(block_size)).sum() / n_col
rstd = 1 / jnp.sqrt(var + eps)
if rstd_ref is not None:
rstd_ref[...] = rstd.astype(rstd_ref.dtype)
def body(i, _):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
weight = pl.load(weight_ref, (col_idx,), mask=mask)
bias = pl.load(bias_ref, (col_idx,), mask=mask)
x = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_first").astype(jnp.float32)
out = x * rstd * weight + bias
pl.store(o_ref, (col_idx,), out.astype(o_ref.dtype), mask=mask)
for_loop(pl.cdiv(n_col, block_size), body, ())
def rms_norm_forward(
x, weight, bias,
num_warps: Optional[int] = None,
num_stages: Optional[int] = 3,
eps: float = 1e-5,
backward_pass_impl: str = 'triton',
interpret: bool = False):
del num_stages
del backward_pass_impl
n = x.shape[-1]
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
kernel = functools.partial(rms_norm_forward_kernel, eps=eps,
block_size=block_size)
out_shape = [
jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype),
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
]
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=(), out_shape=out_shape, debug=False,
interpret=interpret, name='rms_forward')
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
out, rstd = method(x, weight, bias)
return out, (x, weight, bias, rstd)
def rms_norm_backward_kernel_dx(
# Inputs
x_ref, weight_ref, bias_ref, do_ref,
rstd_ref,
# Outputs
dx_ref,
*, eps: float, block_size: int):
n_col = x_ref.shape[0]
def mean_body(i, c1_acc_ref):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
dout = pl.load(do_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
a_hat = a * rstd_ref[...]
wdout = weight * dout
c1_acc_ref[:] += a_hat * wdout
c1 = for_loop(pl.cdiv(n_col, block_size), mean_body, jnp.zeros(block_size))
c1 = c1.sum() / n_col
def dx_body(i, acc_ref):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
dout = pl.load(do_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0.,
eviction_policy="evict_last").astype(jnp.float32)
a_hat = a * rstd_ref[...]
wdout = weight * dout
da = (wdout - (a_hat * c1)) * rstd_ref[...]
pl.store(dx_ref, (col_idx,), da.astype(dx_ref.dtype), mask=mask)
for_loop(pl.cdiv(n_col, block_size), dx_body, ())
def rms_norm_backward_kernel_dw_db(
# Inputs
x_ref, weight_ref, bias_ref, do_ref,
rstd_ref,
# Outputs
dw_ref, db_ref,
*, eps: float, block_m: int, block_n: int):
m, n_col = x_ref.shape
j = pl.program_id(0)
col_idx = j * block_n + jnp.arange(block_n)
col_mask = col_idx < n_col
def body(i, acc_ref):
row_idx = i * block_m + jnp.arange(block_m)
row_mask = row_idx < m
mask = row_mask[:, None] & col_mask[None, :]
a = pl.load(
x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
).astype(jnp.float32)
dout = pl.load(
do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
).astype(jnp.float32)
rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
a_hat = a * rstd[:, None]
dw_acc_ref, db_acc_ref = acc_ref
dw_acc_ref[:] += (dout * a_hat).sum(axis=0)
db_acc_ref[:] += dout.sum(axis=0)
dw_acc, db_acc = for_loop(pl.cdiv(m, block_m), body, (jnp.zeros(block_n), jnp.zeros(block_n)))
pl.store(dw_ref, (col_idx,), dw_acc.astype(dw_ref.dtype), mask=col_mask)
pl.store(db_ref, (col_idx,), db_acc.astype(db_ref.dtype), mask=col_mask)
def rms_norm_backward(
num_warps: Optional[int],
num_stages: Optional[int],
eps: float,
backward_pass_impl: str,
interpret: bool,
res, do):
del num_stages
x, weight, bias, rstd = res
if backward_pass_impl == 'xla':
return jax.vjp(rms_norm_reference, x, weight, bias)[1](do)
*shape_prefix, n = x.shape
reshaped_x = x.reshape((-1, n))
reshaped_rstd = rstd.reshape((-1,))
reshaped_do = do.reshape((-1, n))
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
# rms_norm_backward_kernel_dx parallel over batch dims
kernel = functools.partial(rms_norm_backward_kernel_dx, eps=eps,
block_size=block_size)
out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=(), out_shape=out_shape_dx, debug=False,
                          interpret=interpret, name='rms_backward_dx')
method = jax.vmap(method, in_axes=(0, None, None, 0, 0))
dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
dx = dx.reshape((*shape_prefix, n))
# rms_norm_backward_kernel_dw_db reduce over batch dims
# Triton heuristics
if n > 10240:
block_n = 128
block_m = 32
num_warps = 4
else:
# maximize occupancy for small N
block_n = 16
block_m = 16
num_warps = 8
kernel = functools.partial(rms_norm_backward_kernel_dw_db, eps=eps,
block_m=block_m, block_n=block_n)
out_shape_dwbias = [
jax.ShapeDtypeStruct(shape=weight.shape, dtype=weight.dtype),
jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
]
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=grid_, out_shape=out_shape_dwbias, debug=False,
                          interpret=interpret, name='rms_backward_dw_db')
dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
return dx, dw, dbias
@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7])
@functools.partial(jax.jit, static_argnames=["num_warps", "num_stages",
"num_stages", "eps",
"backward_pass_impl",
"interpret"])
def rms_norm(
x, weight, bias,
num_warps: Optional[int] = None,
num_stages: Optional[int] = 3,
eps: float = 1e-5,
backward_pass_impl: str = 'triton',
interpret: bool = False):
n = x.shape[-1]
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
kernel = functools.partial(rms_norm_forward_kernel, eps=eps,
block_size=block_size)
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
grid=(), out_shape=out_shape, debug=False,
interpret=interpret)
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
return method(x, weight, bias)
rms_norm.defvjp(rms_norm_forward, rms_norm_backward)
@functools.partial(jax.jit, static_argnames=["eps"])
@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0)
def rms_norm_reference(x, weight, bias, *, eps: float = 1e-5):
var = jnp.mean(jnp.square(x), axis=1)
mul = lax.rsqrt(var + eps)
return x * mul[:, None] * weight[None] + bias[None]
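As a quick numerical sanity check of the reference above (illustrative shapes; the tolerance only needs to absorb rsqrt-vs-sqrt rounding):

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 32), jnp.float32)
w = jnp.full((32,), 0.5)
b = jnp.zeros((32,))
ref = rms_norm_reference(x, w, b)
rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-5)
assert jnp.allclose(ref, x / rms * w + b, atol=1e-5)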

View File

@ -0,0 +1,86 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pallas softmax kernel."""
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
def _vmappable_softmax_kernel(
# inputs
input_ref,
# outputs
probs_ref,
*,
# block information
# It is assumed that block_row >= row_len
block_row: int,
):
row_len = input_ref.shape[-1]
mask = jnp.arange(block_row) < row_len
row = pl.load(
input_ref, (pl.dslice(0, block_row),), mask=mask, other=-float("inf")
)
row_max = jnp.max(row, axis=0)
numerator = jnp.exp((row - row_max).astype(jnp.float32))
denominator = jnp.sum(numerator, axis=0)
pl.store(
probs_ref, (pl.dslice(0, block_row),),
(numerator / denominator).astype(probs_ref.dtype),
mask=mask
)
@functools.partial(jax.jit, static_argnames=["axis", "num_warps", "interpret",
"debug"])
def softmax(
x: jax.Array, *, axis: int = -1, num_warps: int = 4,
interpret: bool = False, debug: bool = False
) -> jax.Array:
"""Computes the softmax of the input array along the specified axis.
Args:
x: input array
axis: the axis along which to perform the computation
num_warps: the number of warps to use for executing the Triton kernel
interpret: whether to interpret the kernel using pallas
debug: whether to use pallas in debug mode
Returns:
The result of the softmax operation over the specified axis of x.
"""
axis = axis if axis >= 0 else len(x.shape) + axis
if axis != len(x.shape) - 1:
raise NotImplementedError(
"reductions along non-trailing dimension unsupported")
row_len = x.shape[-1]
block_row = pl.next_power_of_2(row_len)
out_shape = jax.ShapeDtypeStruct(shape=(row_len,), dtype=x.dtype)
kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
f = pl.pallas_call(kernel, num_warps=num_warps, num_stages=1, grid=(),
out_shape=out_shape, debug=debug, interpret=interpret)
for _ in range(len(x.shape) - 1):
f = jax.vmap(f)
return f(x)
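Example usage, assuming interpret mode to avoid the need for a GPU (the input shape is illustrative):

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (8, 128), jnp.float32)
probs = softmax(x, axis=-1, interpret=True)
assert jnp.allclose(probs, jax.nn.softmax(x, axis=-1), atol=1e-5)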

View File

@ -0,0 +1,22 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains Mosaic specific Pallas functions."""
from jax._src.pallas.mosaic import CMEM
from jax._src.pallas.mosaic import PrefetchScalarGridSpec
from jax._src.pallas.mosaic import SMEM
from jax._src.pallas.mosaic import TPUMemorySpace
from jax._src.pallas.mosaic import VMEM
from jax._src.pallas.mosaic import repeat
from jax._src.pallas.mosaic import trace

View File

@ -37,6 +37,9 @@ tf_cuda_tests_tags = _tf_cuda_tests_tags
jax_internal_packages = []
mosaic_internal_users = []
pallas_gpu_internal_users = []
pallas_tpu_internal_users = []
jax_test_util_visibility = []
loops_visibility = []

tests/pallas/BUILD
View File

@ -0,0 +1,60 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"jax_test",
"py_deps",
)
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//visibility:private"],
)
jax_test(
name = "pallas_test",
srcs = [
"pallas_test.py",
],
disable_backends = [
"cpu",
"tpu",
],
disable_configs = [
"gpu",
"gpu_a100",
"gpu_p100",
],
enable_configs = [
"gpu_x32",
"gpu_a100_x32",
],
shard_count = 4,
deps = [
"//third_party/py/jax:pallas",
] + py_deps("absl/testing") + py_deps("numpy"),
)
py_test(
name = "indexing_test",
srcs = [
"indexing_test.py",
],
deps = [
"//third_party/py/jax:pallas",
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)

View File

@ -0,0 +1,168 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Pallas indexing logic and abstractions."""
from typing import Union
from absl.testing import absltest
from absl.testing import parameterized
import hypothesis as hp
import hypothesis.extra.numpy as hnp
import hypothesis.strategies as hps
import jax
from jax._src import util
from jax._src.pallas import indexing
import numpy as np
Slice = indexing.Slice
NDIndexer = indexing.NDIndexer
ds = indexing.ds
def int_indexer_strategy(dim) -> hps.SearchStrategy[int]:
return hps.integers(min_value=np.iinfo(np.int32).min, max_value=dim - 1)
@hps.composite
def slice_indexer_strategy(draw, dim) -> Union[Slice, slice]:
start = draw(int_indexer_strategy(dim))
size = draw(hps.integers(min_value=0, max_value=np.iinfo(np.int32).max))
return draw(
hps.one_of(
hps.just(Slice(start, size)), hps.just(slice(start, start + size))
)
)
@hps.composite
def array_indexer_strategy(draw, shape) -> jax.Array:
unbcast = [draw(hps.booleans()) for _ in shape]
shape = tuple(1 if unb else s for unb, s in zip(unbcast, shape))
return draw(hnp.arrays(dtype=np.dtype("int32"), shape=shape))
@hps.composite
def indexer_strategy(draw, dim, int_indexer_shape
) -> Union[int, Slice, jax.Array]:
return draw(hps.one_of(
int_indexer_strategy(dim),
slice_indexer_strategy(dim),
array_indexer_strategy(int_indexer_shape),
))
@hps.composite
def nd_indexer_strategy(draw, shape) -> NDIndexer:
num_indices = draw(hps.integers(min_value=0, max_value=len(shape)))
int_indexer_shape = draw(hnp.array_shapes())
indices = [draw(indexer_strategy(dim, int_indexer_shape)) for dim
in shape[:num_indices]]
return NDIndexer.from_indices_shape(indices, shape)
class IndexerTest(parameterized.TestCase):
def test_simple_ndindexer(self):
indices = (0, 0)
shape = (5, 5)
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), ())
def test_invalid_ndindexer(self):
indices = (0, 0, 0)
shape = (5, 5)
with self.assertRaises(ValueError):
_ = NDIndexer.from_indices_shape(indices, shape)
def test_ndindexer_with_padding(self):
indices = ()
shape = (5, 5)
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), shape)
def test_ndindexer_with_slices(self):
indices = (slice(2, 3), slice(4, 7))
shape = (5, 5)
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (1, 3))
def test_ndindexer_with_arrays(self):
indices = (np.arange(10), np.arange(10))
shape = (5, 5)
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (10,))
indices = (np.ones((10, 20)), np.ones((10, 20)))
shape = (5, 5)
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20))
def test_ndindexer_with_arrays_and_broadcasting(self):
indices = (np.arange(10)[None], np.arange(20)[:, None])
shape = (5, 5)
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (20, 10))
indices = (np.arange(10)[:, None], np.arange(20)[None, :])
shape = (5, 5)
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20))
def test_indexer_with_all_types(self):
indices = (0, slice(10), np.arange(5))
shape = (2, 3, 4)
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 10))
indices = (0, slice(4, 10), np.arange(5))
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 6))
indices = (0, 5, np.arange(5))
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5,))
indices = (ds(2, 3), np.arange(5)[:, None], np.arange(4)[None])
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 3))
@hp.given(hps.data())
def test_ndindexer(self, data):
shape = data.draw(hnp.array_shapes())
indexer = data.draw(nd_indexer_strategy(shape))
is_int_indexer = [not isinstance(idx, Slice) for idx in indexer.indices]
rest_indexers, int_indexers = util.partition_list(
is_int_indexer, indexer.indices
)
if int_indexers:
expected_int_indexer_shape = int_indexers[0].shape
else:
expected_int_indexer_shape = ()
self.assertTupleEqual(
indexer.int_indexer_shape, expected_int_indexer_shape
)
for idx in rest_indexers:
self.assertIsInstance(idx, (np.ndarray, Slice))
if isinstance(idx, np.ndarray):
self.assertTupleEqual(idx.shape, ())
self.assertEqual(idx.dtype, np.dtype("int32"))
rest_shape = tuple(
r.size for r in rest_indexers if not isinstance(r, np.ndarray)
)
self.assertTupleEqual((*indexer.int_indexer_shape, *rest_shape),
indexer.get_indexer_shape())
if __name__ == "__main__":
absltest.main()

tests/pallas/pallas_test.py

File diff suppressed because it is too large