Merge pull request #294 from ROCm/ci-upstream-sync-151_1

CI: 03/18/25 upstream sync
This commit is contained in:
charleshofer 2025-03-18 11:10:16 -05:00 committed by GitHub
commit c46b4fc02b
92 changed files with 1838 additions and 368 deletions

View File

@ -118,6 +118,11 @@ jobs:
run: |
$JAXCI_PYTHON -m pip install uv~=0.5.30
$JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
# CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
if [[ $OS == "linux" && $ARCH == "aarch64" ]]; then
$JAXCI_PYTHON -m uv pip install numpy~=2.1.0
fi
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main

View File

@ -54,7 +54,8 @@ jobs:
runs-on: ${{ inputs.runner }}
# TODO: Update to the generic ML ecosystem test containers when they are ready.
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') ||
(contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }}
name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
env:

View File

@ -110,18 +110,30 @@ jobs:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix strategy in the artifact build jobs above
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
# See exclusions for what is fully tested
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
python: ["3.10",]
cuda: ["12.3", "12.1"]
cuda: ["12.1","12.3","12.8"]
enable-x64: [1, 0]
exclude:
# Run only a single configuration on H100 to save resources
# L4 does not run on CUDA 12.8 but tests the other configs
- runner: "linux-x86-g2-48-l4-4gpu"
cuda: "12.8"
# H100 runs only a single config: CUDA 12.3 with x64 enabled
- runner: "linux-x86-a3-8g-h100-8gpu"
cuda: "12.8"
- runner: "linux-x86-a3-8g-h100-8gpu"
python: "3.10"
cuda: "12.1"
- runner: "linux-x86-a3-8g-h100-8gpu"
python: "3.10"
enable-x64: 0
enable-x64: "0"
# B200 runs only a single config: CUDA 12.8 with x64 enabled
- runner: "linux-x86-a4-224-b200-1gpu"
enable-x64: "0"
- runner: "linux-x86-a4-224-b200-1gpu"
cuda: "12.1"
- runner: "linux-x86-a4-224-b200-1gpu"
cuda: "12.3"
name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.runner }}

View File

@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
true, matching the current behavior. If set to false, JAX does not need to
emit code clamping negative indices, which improves code size.
* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
without replacement.
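
  A brief usage sketch of the new option (illustrative values, not part of the
  changelog); `replace=True` remains the default:

  ```python
  import jax
  import jax.numpy as jnp

  key = jax.random.key(0)
  logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))

  # Default behavior: sampling with replacement (Gumbel-max trick).
  with_repl = jax.random.categorical(key, logits, shape=(3,))

  # New: sampling without replacement (Gumbel top-k trick); the number of
  # draws must not exceed the number of categories.
  without_repl = jax.random.categorical(key, logits, shape=(3,), replace=False)
  ```
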
## jax 0.5.2 (Mar 4, 2025)

View File

@ -18,7 +18,4 @@ setuptools
matplotlib~=3.8.4; python_version=="3.10"
matplotlib; python_version>="3.11"
opt-einsum
auditwheel
# CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
numpy~=2.1.0; platform_system == "Linux" and platform_machine == "aarch64"
auditwheel

View File

@ -49,13 +49,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hVi6mApuVw3r",
"outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf"
"id": "hVi6mApuVw3r"
},
"outputs": [],
"source": [
@ -84,13 +80,13 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mzDIDvj7Vw0k",
"outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434"
"outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a"
},
"outputs": [
{
@ -119,13 +115,13 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IyPx_-IBVwxr",
"outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499"
"outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb"
},
"outputs": [
{
@ -141,7 +137,7 @@
"Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)"
]
},
"execution_count": 3,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -172,13 +168,13 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NO2ulM_QW7a8",
"outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb"
"outputId": "d888371b-080e-4bff-be5d-ea56beda3aac"
},
"outputs": [
{
@ -208,13 +204,13 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1-TzmA0AXCAf",
"outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71"
"outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2"
},
"outputs": [
{
@ -256,13 +252,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Gy7ABds3XND3",
"outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b"
"outputId": "0d72dad2-381a-4e96-f771-40d705da1376"
},
"outputs": [
{
@ -297,13 +293,13 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "grCcotr-XQjY",
"outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a"
"outputId": "c2db656c-809f-49a6-c948-629d6420360c"
},
"outputs": [
{
@ -324,7 +320,7 @@
" [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)"
]
},
"execution_count": 7,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@ -460,13 +456,13 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fpFEaMBcXsJG",
"outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660"
"outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef"
},
"outputs": [
{
@ -479,13 +475,6 @@
"We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n",
"Result type: ShapedArray(int32[4@X,4])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result type: ShapedArray(int32[4@X,4])\n"
]
}
],
"source": [
@ -550,13 +539,13 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "geptWrdYX0OM",
"outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f"
"outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f"
},
"outputs": [
{
@ -588,7 +577,88 @@
{
"cell_type": "markdown",
"metadata": {
"id": "AQQjzUeGX4P6"
"id": "LZWjgiMZ7uSS"
},
"source": [
"You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IVzPSkp77uCF",
"outputId": "db80a604-98ac-4343-8677-23729adf7ffc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n",
"x.sharding: ShapedArray(float32[4@X,4@Y])\n",
"\n",
"mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n",
"y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n",
"\n",
"z.sharding: ShapedArray(float32[4@X,4@Y])\n",
"\n"
]
},
{
"data": {
"text/plain": [
"Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],\n",
" [-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],\n",
" [ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],\n",
" [-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import functools\n",
"\n",
"@functools.partial(auto_axes, axes='X')\n",
"def g(y):\n",
" print(f'mesh inside g: {get_abstract_mesh()}')\n",
" print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n",
" return y * 2\n",
"\n",
"@jax.jit\n",
"def f(arr1):\n",
" print(f'mesh inside f: {get_abstract_mesh()}')\n",
" x = jnp.sin(arr1)\n",
" print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n",
"\n",
" z = g(x, out_shardings=P(\"X\", \"Y\"))\n",
"\n",
" print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n",
" return z + 1\n",
"\n",
"some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n",
"f(some_x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3sfJjRq8w9f"
},
"source": [
"As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sJcWbfAh7UcO"
},
"source": [
"## Concrete array shardings can mention `Auto` mesh axis\n",
@ -606,7 +676,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -708,5 +778,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 0
}

View File

@ -50,12 +50,8 @@ expect there to be bugs and unimplemented cases. Please let us know when you
find something that doesn't work!
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: hVi6mApuVw3r
outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf
---
:id: hVi6mApuVw3r
import jax
import numpy as np
import jax.numpy as jnp
@ -79,7 +75,7 @@ scalar) using `jax.typeof`:
colab:
base_uri: https://localhost:8080/
id: mzDIDvj7Vw0k
outputId: 417b8453-9c86-4e76-a886-4fa9fdb16434
outputId: 09ef049b-461f-47db-bf58-dc10b42fe40a
---
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
@ -96,7 +92,7 @@ under a jit).
colab:
base_uri: https://localhost:8080/
id: IyPx_-IBVwxr
outputId: 7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499
outputId: 0cd3122f-e579-45d7-868d-e42bb0eacddb
---
@jax.jit
def foo(x):
@ -121,7 +117,7 @@ mesh afterwards then you can use the context manager `jax.sharding.use_mesh` ins
colab:
base_uri: https://localhost:8080/
id: NO2ulM_QW7a8
outputId: ea313610-146c-41f4-95b4-c5a5b2b407cb
outputId: d888371b-080e-4bff-be5d-ea56beda3aac
---
mesh = jax.make_mesh((2, 4), ("X", "Y"),
axis_types=(AxisType.Explicit, AxisType.Explicit))
@ -139,7 +135,7 @@ Now we can create some sharded arrays using `reshard`:
colab:
base_uri: https://localhost:8080/
id: 1-TzmA0AXCAf
outputId: 15b33b6d-3915-4725-da6d-4f31fb78fe71
outputId: 1c7cc3ac-4b0e-42b7-facc-c706af10d7d2
---
replicated_array = np.arange(8).reshape(4, 2)
sharded_array = reshard(replicated_array, P("X", None))
@ -163,7 +159,7 @@ These shardings associated with JAX-level types propagate through operations. Fo
colab:
base_uri: https://localhost:8080/
id: Gy7ABds3XND3
outputId: 4ced73ed-5872-45f3-a4a6-2138f942e01b
outputId: 0d72dad2-381a-4e96-f771-40d705da1376
---
arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))
@ -184,7 +180,7 @@ We can do the same type querying under a jit:
colab:
base_uri: https://localhost:8080/
id: grCcotr-XQjY
outputId: 9a9f381d-5111-4824-9bc0-cb2472cb8e6a
outputId: c2db656c-809f-49a6-c948-629d6420360c
---
@jax.jit
def add_arrays(x, y):
@ -294,7 +290,7 @@ the first axis only, like `f32[4@X, 4]`. You can do this as follows:
colab:
base_uri: https://localhost:8080/
id: fpFEaMBcXsJG
outputId: d28a69eb-260f-4fc5-8f19-2cc64cc70660
outputId: 5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef
---
some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))
@ -355,7 +351,7 @@ The current mesh tells us which sharding mode we're in. We can query it with
colab:
base_uri: https://localhost:8080/
id: geptWrdYX0OM
outputId: c0e62eb1-9f79-4d1c-e708-526165ca680f
outputId: b8c3813f-60bb-4ccf-9da7-73462c57963f
---
print(f"Current mesh is: {get_abstract_mesh()}")
```
@ -369,7 +365,45 @@ sharding mode for each mesh axis. Shardings (on JAX-level types) can only
mention _explicit_ mesh axes and collective operations like `psum` can only
mention _manual_ mesh axes.
+++ {"id": "AQQjzUeGX4P6"}
+++ {"id": "LZWjgiMZ7uSS"}
You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over others. For example:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: IVzPSkp77uCF
outputId: db80a604-98ac-4343-8677-23729adf7ffc
---
import functools
@functools.partial(auto_axes, axes='X')
def g(y):
print(f'mesh inside g: {get_abstract_mesh()}')
print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n')
return y * 2
@jax.jit
def f(arr1):
print(f'mesh inside f: {get_abstract_mesh()}')
x = jnp.sin(arr1)
print(f'x.sharding: {jax.typeof(x)}', end='\n\n')
z = g(x, out_shardings=P("X", "Y"))
print(f'z.sharding: {jax.typeof(z)}', end="\n\n")
return z + 1
some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y"))
f(some_x)
```
+++ {"id": "_3sfJjRq8w9f"}
As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])`, which indicates it is Explicit over the `Y` mesh axis and Auto over `X`.
+++ {"id": "sJcWbfAh7UcO"}
## Concrete array shardings can mention `Auto` mesh axis

View File

@ -299,7 +299,7 @@
" ):\n",
" \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n",
" del idxs_k_ref\n",
" blk_idx = pl.program_id(0)\n",
" blk_idx = pl.program_id(1)\n",
" is_start = blk_idx == 0\n",
" changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n",
" @pl.when(is_start | changed_blocks)\n",
@ -314,13 +314,13 @@
" o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n",
"\n",
"\n",
"def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
"def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
" del j, blk_idxs_i, blk_idxs_k\n",
" return (blk_idx, 0, 0)\n",
"def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
"def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
" del blk_idxs_i\n",
" return (blk_idxs_k[blk_idx], j)\n",
"def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
"def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
" del blk_idxs_k\n",
" return (blk_idxs_i[blk_idx], j)\n",
"\n",
@ -335,7 +335,7 @@
" num_scalar_prefetch=2,\n",
" # Note that while num_blocks is static here, Pallas does support\n",
" # dynamic grid sizes.\n",
" grid=(num_blocks, N // blk_N),\n",
" grid=(N // blk_N, num_blocks),\n",
" in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n",
" pl.BlockSpec((blk_K, blk_N), y_map),\n",
" # Placeholder for a zeros-array used by input_output_aliases.\n",

View File

@ -239,7 +239,7 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
):
"""A DSD (Dense = Sparse @ Dense) matmul kernel."""
del idxs_k_ref
blk_idx = pl.program_id(0)
blk_idx = pl.program_id(1)
is_start = blk_idx == 0
changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
@pl.when(is_start | changed_blocks)
@ -254,13 +254,13 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
del j, blk_idxs_i, blk_idxs_k
return (blk_idx, 0, 0)
def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
del blk_idxs_i
return (blk_idxs_k[blk_idx], j)
def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
del blk_idxs_k
return (blk_idxs_i[blk_idx], j)
@ -275,7 +275,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=2,
# Note that while num_blocks is static here, Pallas does support
# dynamic grid sizes.
grid=(num_blocks, N // blk_N),
grid=(N // blk_N, num_blocks),
in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
pl.BlockSpec((blk_K, blk_N), y_map),
# Placeholder for a zeros-array used by input_output_aliases.

View File

@ -81,7 +81,7 @@ int main(int argc, char** argv) {
xla::XlaComputation xla_computation(test_module_proto);
xla::CompileOptions compile_options;
std::unique_ptr<xla::PjRtLoadedExecutable> executable =
client->Compile(xla_computation, compile_options).value();
client->CompileAndLoad(xla_computation, compile_options).value();
// Prepare inputs.
xla::Literal literal_x =

View File

@ -799,7 +799,7 @@ pytype_strict_library(
)
# This target only supports sm_90 GPUs.
py_library(
py_library_providing_imports_info(
name = "mosaic_gpu",
srcs = glob(["experimental/mosaic/gpu/*.py"]),
visibility = [
@ -824,6 +824,7 @@ py_library(
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:vector_dialect",
"//jaxlib/mosaic/python:gpu_dialect",
] + py_deps("absl/flags") + py_deps("numpy"),
)

View File

@ -67,7 +67,9 @@ from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
from jax._src.mesh import get_concrete_mesh
from jax._src.sharding_impls import (
PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding)
from jax._src.layout import Layout, AutoLayout
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
@ -2280,11 +2282,20 @@ def _check_sharding(aval, s):
(s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False)
s.shard_shape(aval.shape) # should raise an Error if incompatible
def pspec_to_sharding(val):
if isinstance(val, P):
mesh = get_concrete_mesh()
if mesh is None:
raise ValueError(
"Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is"
" passed to device_put")
return NamedSharding(mesh, val)
return val
def device_put(
x,
device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
*, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
*, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
donate: bool | Any = False, may_alias: bool | None | Any = None):
"""Transfers ``x`` to ``device``.
@ -2333,6 +2344,9 @@ def device_put(
src_flat = flatten_axes("device_put source", treedef, src)
src_flat = list(map(_infer_src_sharding, src_flat, x_flat))
device_flat = map(pspec_to_sharding, device_flat)
src_flat = map(pspec_to_sharding, src_flat)
if isinstance(donate, bool):
donate_flat = [donate] * len(x_flat)
else:
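
A minimal usage sketch of the `PartitionSpec` shorthand added to `device_put` above; the mesh construction and the axis name `"x"` are illustrative, and a concrete mesh must be active (e.g. via `jax.sharding.use_mesh`), otherwise the `ValueError` shown above is raised:

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# Hypothetical 1D mesh over all local devices; the axis name "x" is arbitrary.
mesh = jax.make_mesh((jax.device_count(),), ("x",))

with jax.sharding.use_mesh(mesh):
    x = jnp.arange(4.0 * jax.device_count())
    # With a concrete mesh in scope, a bare PartitionSpec now stands in for a
    # NamedSharding(mesh, spec): this shards the leading axis over "x".
    y = jax.device_put(x, P("x"))
```
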

View File

@ -28,17 +28,17 @@ class SampleFn(Protocol):
...
def _compute_scalar_index(iteration_index: Sequence[int],
total_size: Shape,
block_size: Shape,
block_index: Sequence[int]) -> int:
ndims = len(iteration_index)
def _compute_tile_index(block_index: Sequence[int],
total_size_in_blocks: Shape,
block_size_in_tiles: Shape,
tile_index_in_block: Sequence[int]) -> int:
ndims = len(block_index)
dim_size = 1
total_idx = 0
for i in range(ndims-1, -1, -1):
dim_idx = block_index[i] + iteration_index[i] * block_size[i]
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
total_idx += dim_idx * dim_size
dim_size *= total_size[i]
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
return total_idx
@ -99,18 +99,23 @@ def blocked_fold_in(
An N-dimensional nested list of keys required to sample the tiles
corresponding to the block specified by `block_index`.
"""
size_in_blocks = tuple(
_shape // _element for _shape, _element in zip(block_size, tile_size))
block_size_in_tiles = tuple(
_shape // _element for _shape, _element in zip(block_size, tile_size)
)
total_size_in_blocks = tuple(
_shape // _element for _shape, _element in zip(total_size, block_size)
)
def _keygen_loop(axis, prefix):
if axis == len(size_in_blocks):
if axis == len(block_size_in_tiles):
subtile_key = jax.random.fold_in(
global_key, _compute_scalar_index(
block_index, total_size, size_in_blocks, prefix))
global_key, _compute_tile_index(
block_index, total_size_in_blocks, block_size_in_tiles, prefix))
return subtile_key
else:
keys = []
for i in range(size_in_blocks[axis]):
for i in range(block_size_in_tiles[axis]):
keys.append(_keygen_loop(axis+1, prefix+(i,)))
return keys
return _keygen_loop(0, tuple())

View File

@ -446,7 +446,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
if len(devices) == 1:
# If we only have one device in our computation, we can construct a
# replicated HloSharding and call it right now.
_hlo_sharding_callback(sharding_impls.get_replicated_hlo_sharding())
_hlo_sharding_callback(sharding_impls.replicated_hlo_sharding)
return []
key = xc.encode_inspect_sharding_callback(_hlo_sharding_callback)

View File

@ -466,11 +466,14 @@ def _device_put_sharding_impl(x, aval, device, copy):
if not s.is_fully_addressable:
if ((isinstance(x, array.ArrayImpl) and not x._committed) or
type(x) in array_types):
multihost_utils.assert_equal(
x, fail_message=(
f"{type(x)} passed to device_put is not the same on each"
" process. Make sure you are passing the same value of"
f" {type(x)} on each process."))
# TODO(emilyaf): Remove this condition when jit works when a sharding
# has no local devices.
if not config.enable_empty_arrays.value:
multihost_utils.assert_equal(
x, fail_message=(
f"{type(x)} passed to device_put is not the same on each"
" process. Make sure you are passing the same value of"
f" {type(x)} on each process."))
return _DeferredShardArg(x, s, aval, True, copy)
# TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
raise ValueError(

View File

@ -14,13 +14,17 @@
from __future__ import annotations
from functools import partial
import threading
import jax
from jax._src import core
from jax._src import source_info_util
from jax._src import traceback_util
import jax._src.mesh as mesh_lib
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
Traceback = source_info_util.Traceback
@ -54,17 +58,61 @@ _error_storage = _ErrorStorage()
def _initialize_error_code_ref() -> None:
"""Initialize error_code_ref in the current thread."""
"""Initialize error_code_ref in the current thread.
The size of the error code array is determined by the mesh in the context. In a
single-device environment, the array is a scalar. In a multi-device
environment, the array has the same shape as the mesh.
"""
with core.eval_context():
error_code = jnp.uint32(_NO_ERROR)
# Get mesh from the context.
mesh = mesh_lib.get_concrete_mesh()
if mesh is None: # single-device case.
error_code = jnp.uint32(_NO_ERROR)
else: # multi-device case.
sharding = NamedSharding(mesh, P(*mesh.axis_names))
error_code = jnp.full(
mesh.axis_sizes,
jnp.uint32(_NO_ERROR),
device=sharding,
)
_error_storage.ref = core.mutable_array(error_code)
def set_error_if(pred: jax.Array, msg: str) -> None:
class error_checking_context:
"""Redefine the error checking state based on the mesh in the context.
This context manager should be used when starting a multi-device
computation, and whenever the mesh is changed.
When exiting the context, the error checking state will be reset to the
original state.
"""
__slots__ = ("old_ref",)
def __init__(self):
self.old_ref = None
def __enter__(self):
self.old_ref = _error_storage.ref
_initialize_error_code_ref()
return self
def __exit__(self, exc_type, exc_value, traceback):
_error_storage.ref = self.old_ref
def set_error_if(pred: jax.Array, /, msg: str) -> None:
"""Set error if any element of pred is true.
If the error is already set, the new error will be ignored. It will not
override the existing error.
In auto mode, this function does not work under jit.
"""
if _error_storage.ref is None:
_initialize_error_code_ref()
@ -76,7 +124,32 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
new_error_code = jnp.uint32(len(_error_list))
_error_list.append((msg, traceback))
pred = pred.any()
out_sharding = core.typeof(_error_storage.ref).sharding
in_sharding: NamedSharding = core.typeof(pred).sharding
if out_sharding.mesh.shape_tuple == (): # single-device case.
pred = pred.any()
else: # multi-device case.
has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types
if has_auto_axes:
raise NotImplementedError(
"Error checking in auto mode is not supported yet. Please use"
" explicit mode."
)
if out_sharding.mesh != in_sharding.mesh:
raise ValueError(
"The error code state and the predicate must be on the same mesh, "
f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. "
"Please use `with error_checking_context()` to redefine the error "
"code state based on the mesh."
)
pred = shard_map(
partial(jnp.any, keepdims=True),
mesh=out_sharding.mesh,
in_specs=in_sharding.spec,
out_specs=out_sharding.spec,
)(pred) # perform per-device reduction
error_code = _error_storage.ref[...]
should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR))
error_code = jnp.where(should_update, new_error_code, error_code)
@ -93,7 +166,7 @@ def raise_if_error() -> None:
if _error_storage.ref is None: # if not initialized, do nothing
return
error_code = _error_storage.ref[...]
error_code = _error_storage.ref[...].min() # reduce to a single error code
if isinstance(error_code, core.Tracer):
raise ValueError(
"raise_if_error() should not be called within a traced context, such as"
@ -101,7 +174,11 @@ def raise_if_error() -> None:
)
if error_code == jnp.uint32(_NO_ERROR):
return
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)
_error_storage.ref[...] = jnp.full(
_error_storage.ref.shape,
jnp.uint32(_NO_ERROR),
device=_error_storage.ref.sharding,
) # clear the error code
msg, traceback = _error_list[error_code]
exc = JaxValueError(msg)
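
A hedged, single-device usage sketch of the API above; the import path mirrors the module being edited here and may not be the intended public entry point. The multi-device path additionally requires a concrete mesh plus `error_checking_context()`, as described in the docstrings above:

```python
import jax.numpy as jnp
from jax._src import error_check  # assumed path; the public location may differ

x = jnp.array([1.0, -1.0])

# With no mesh set, the error state is a scalar (the single-device case above).
error_check.set_error_if(x < 0, "negative input encountered")

# Raises error_check.JaxValueError("negative input encountered") because the
# predicate was true for at least one element; otherwise this is a no-op.
error_check.raise_if_error()
```
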

View File

@ -322,12 +322,15 @@ vmappables: dict[type, tuple[type, type]] = {}
spec_types: set[type] = {JumbleAxis}
def unregister_vmappable(data_type: type) -> None:
spec_type, axis_size_type = vmappables.pop(data_type)
spec_types.remove(spec_type)
_, axis_size_type = vmappables.pop(data_type)
del to_elt_handlers[data_type]
del from_elt_handlers[data_type]
if axis_size_type in make_iota_handlers:
del make_iota_handlers[axis_size_type]
global spec_types
spec_types = (
{JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()}
)
def is_vmappable(x: Any) -> bool:
return type(x) is Jumble or type(x) in vmappables

View File

@ -797,7 +797,7 @@ def tracers_to_jaxpr(
processed_eqn_ids = set()
eqns: list[core.JaxprEqn] = []
for t in toposort([*in_tracers, *out_tracers]):
for t in toposort((*in_tracers, *out_tracers)):
r = t.recipe
if isinstance(r, JaxprEqnRecipe):
# TODO broadcast_in_dim can create a new tracer, not present in parents

View File

@ -237,7 +237,7 @@ def batched_device_put(aval: core.ShapedArray,
if (isinstance(x, array.ArrayImpl) and
dispatch.is_single_device_sharding(x.sharding) and
x.devices() == {d})]
if len(bufs) == len(xs):
if len(bufs) == len(xs) > 0:
return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True)
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)

View File

@ -1026,24 +1026,101 @@ def clz(x: ArrayLike) -> Array:
r"""Elementwise count-leading-zeros."""
return clz_p.bind(x)
@export
def add(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise addition: :math:`x + y`."""
r"""Elementwise addition: :math:`x + y`.
This function lowers directly to the `stablehlo.add`_ operation.
Args:
x, y: Input arrays. Must have matching numerical dtypes. If neither
is a scalar, ``x`` and ``y`` must have the same number of dimensions
and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the sum
of each pair of broadcasted entries.
See also:
- :func:`jax.numpy.add`: NumPy-style addition supporting inputs
with mixed dtypes and ranks.
.. _stablehlo.add: https://openxla.org/stablehlo/spec#add
"""
return add_p.bind(x, y)
@export
def sub(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise subtraction: :math:`x - y`."""
r"""Elementwise subtraction: :math:`x - y`.
This function lowers directly to the `stablehlo.subtract`_ operation.
Args:
x, y: Input arrays. Must have matching numerical dtypes. If neither
is a scalar, ``x`` and ``y`` must have the same number of dimensions
and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the difference
of each pair of broadcasted entries.
See also:
- :func:`jax.numpy.subtract`: NumPy-style subtraction supporting
inputs with mixed dtypes and ranks.
.. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract
"""
return sub_p.bind(x, y)
@export
def mul(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise multiplication: :math:`x \times y`."""
r"""Elementwise multiplication: :math:`x \times y`.
This function lowers directly to the `stablehlo.multiply`_ operation.
Args:
x, y: Input arrays. Must have matching numerical dtypes. If neither
is a scalar, ``x`` and ``y`` must have the same number of dimensions
and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the product
of each pair of broadcasted entries.
See also:
- :func:`jax.numpy.multiply`: NumPy-style multiplication supporting
inputs with mixed dtypes and ranks.
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
"""
return mul_p.bind(x, y)
@export
def div(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise division: :math:`x \over y`.
Integer division overflow
(division by zero or signed division of INT_SMIN with -1)
produces an implementation defined value.
This function lowers directly to the `stablehlo.divide`_ operation.
Integer division overflow (division by zero or signed division of
INT_SMIN with -1) produces an implementation defined value.
Args:
x, y: Input arrays. Must have matching numerical dtypes. If neither
is a scalar, ``x`` and ``y`` must have the same number of dimensions
and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the quotient
of each pair of broadcasted entries. For integer inputs, any fractional
part is discarded.
See also:
- :func:`jax.numpy.divide`: NumPy-style true division supporting
inputs with mixed dtypes and ranks.
- :func:`jax.numpy.floor_divide`: NumPy-style floor division supporting
inputs with mixed dtypes and ranks.
.. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide
"""
return div_p.bind(x, y)
@ -8422,3 +8499,13 @@ mlir.register_lowering(optimization_barrier_p,
def _optimization_barrier_batcher(batched_args, batch_dims, **params):
return optimization_barrier_p.bind(*batched_args, **params), batch_dims
batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher
def _opt_barrier_jvp(primals, tangents):
tangents = [ad.instantiate_zeros(t) for t in tangents]
return optimization_barrier(primals), optimization_barrier(tangents)
ad.primitive_jvps[optimization_barrier_p] = _opt_barrier_jvp
def _opt_barrier_transpose(cts, *primals):
cts = [ad.instantiate_zeros(ct) for ct in cts]
return optimization_barrier(cts)
ad.primitive_transposes[optimization_barrier_p] = _opt_barrier_transpose
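
A short illustration (not part of the diff) of the contrast drawn in the new docstrings between the strict `jax.lax` primitives and their NumPy-style counterparts:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4, dtype=jnp.float32)
y = jnp.ones(4, dtype=jnp.float32)

lax.add(x, y)   # lowers directly to stablehlo.add; dtypes and ranks must match
jnp.add(x, 1)   # NumPy-style: promotes the Python scalar and broadcasts

# lax.add(x, jnp.ones(4, dtype=jnp.int32))  # would raise: mismatched dtypes
```
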

View File

@ -565,5 +565,5 @@ def use_concrete_mesh(mesh: Mesh | None):
finally:
jax_config.device_context.set_local(prev_val)
def get_concrete_mesh():
def get_concrete_mesh() -> Mesh | None:
return jax_config.device_context.value

View File

@ -15,6 +15,7 @@
"""Module for pallas-core functionality."""
from __future__ import annotations
import collections
from collections.abc import Callable, Iterable, Iterator, Sequence
import contextlib
import copy
@ -1068,6 +1069,17 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **_):
return [], effs
class Mesh(Protocol):
@property
def backend(self) -> str:
...
@property
def shape(self) -> collections.OrderedDict[object, int]:
...
_core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}
@ -1075,9 +1087,8 @@ def default_mesh_discharge_rule(
in_avals,
out_avals,
*args,
grid,
mesh,
compiler_params,
backend,
jaxpr,
debug,
interpret,
@ -1100,19 +1111,22 @@ def default_mesh_discharge_rule(
if isinstance(eff, state_types.WriteEffect)
)
any_spec = BlockSpec(memory_space=MemorySpace.ANY)
grid_spec = GridSpec(
grid=tuple(mesh.shape.items()),
in_specs=[any_spec] * len(in_avals),
out_specs=[any_spec] * len(modified_idxs),
)
from jax._src.pallas import pallas_call # Avoid circular dependency.
outs = pallas_call.pallas_call(
outs = pallas_call._pallas_call(
body,
name=name,
out_shape=[in_avals[idx] for idx in modified_idxs],
in_specs=[any_spec] * len(in_avals),
out_specs=[any_spec] * len(modified_idxs),
input_output_aliases={
in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs)
},
grid=grid,
grid_spec=grid_spec,
mesh=mesh,
compiler_params=compiler_params,
backend=backend,
interpret=interpret,
debug=debug,
cost_estimate=cost_estimate,

View File

@ -340,11 +340,12 @@ def pallas_call_hlo_interpret(
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
compiler_params: Any,
cost_estimate: CostEstimate,
out_avals: tuple[jax_core.AbstractValue, ...],
):
del compiler_params, cost_estimate, out_avals
del mesh, compiler_params, cost_estimate, out_avals
debug_info = jaxpr.debug_info
# If we're in interpret mode, we *scan* over the grid and eval the
# discharged jaxpr.

View File

@ -211,6 +211,10 @@ class TensorCoreMesh:
devices: np.ndarray
axis_names: Sequence[str]
@property
def backend(self) -> str:
return "mosaic_tpu"
@property
def shape(self):
return collections.OrderedDict(zip(self.axis_names, self.devices.shape))
@ -259,7 +263,6 @@ def _tensorcore_mesh_discharge_rule(
compiler_params = TPUCompilerParams()
if len(mesh.shape) > 1:
raise NotImplementedError("Mesh must be 1D")
core_axis_name, num_cores = list(mesh.shape.items())[0]
if compiler_params.dimension_semantics is not None:
raise ValueError(
"dimension_semantics must be None for TensorCoreMesh"
@ -269,13 +272,12 @@ def _tensorcore_mesh_discharge_rule(
out_avals,
*args,
jaxpr=jaxpr,
grid=((core_axis_name, num_cores),),
mesh=mesh,
compiler_params=compiler_params.replace(
dimension_semantics=(PARALLEL,)
),
debug=debug,
interpret=interpret,
backend="mosaic_tpu",
cost_estimate=cost_estimate,
name=name,
)

View File

@ -1351,12 +1351,13 @@ def interpret_pallas_call(
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
compiler_params: Any,
cost_estimate: CostEstimate,
out_avals: tuple[jax_core.AbstractValue, ...],
interpret_params: TPUInterpretParams,
):
del debug, cost_estimate, out_avals
del debug, mesh, cost_estimate, out_avals
# args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?)
dynamic_grid_args, scalars, input_args = split_list(

View File

@ -108,6 +108,7 @@ def pallas_call_tpu_lowering_rule(
*in_nodes,
jaxpr: jax_core.Jaxpr,
grid_mapping: core.GridMapping,
mesh: pallas_core.Mesh | None,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
@ -116,7 +117,8 @@ def pallas_call_tpu_lowering_rule(
out_avals: tuple[jax_core.AbstractValue, ...],
):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
del interpret
del mesh, interpret # Unused.
debug_info = jaxpr._debug_info
if debug:
print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
@ -126,11 +128,11 @@ def pallas_call_tpu_lowering_rule(
else:
mosaic_params = {}
mesh = None
jax_mesh = None
axis_context = ctx.module_context.axis_context
if axis_context is not None:
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
mesh = axis_context.mesh
jax_mesh = axis_context.mesh
mlir_ctx = mlir.JaxIrContext()
mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
mlir_ctx.load_all_available_dialects()
@ -147,7 +149,7 @@ def pallas_call_tpu_lowering_rule(
grid_mapping,
jaxpr,
dimension_semantics=dimension_semantics,
mesh=mesh,
mesh=jax_mesh,
for_verification=for_verification,
dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(),
)
@ -164,11 +166,11 @@ def pallas_call_tpu_lowering_rule(
)
if promela_dump_path := _DUMP_PROMELA_TO.value:
num_devices = 1 if mesh is None else mesh.devices.size
num_devices = 1 if jax_mesh is None else jax_mesh.devices.size
num_cores = (
jax.devices()[0].num_cores
if mesh is None
else mesh.devices[0].num_cores
if jax_mesh is None
else jax_mesh.devices[0].num_cores
)
verification_module, _ = lower_module(for_verification=True)
model = verification.export_promela_model(

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import abc
import collections
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
import dataclasses
import enum
import itertools as it
@ -519,9 +519,16 @@ class GPUMesh:
)
@property
def shape(self):
def backend(self) -> str:
return "mosaic_gpu"
@property
def shape(self) -> collections.OrderedDict[object, int]:
pairs: Iterable[tuple[object, int]]
if self.num_threads is not None:
pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads))
pairs = zip(
self.axis_names, (*self.grid, *self.cluster, self.num_threads)
)
else:
pairs = tuple(
zip(
@ -563,8 +570,7 @@ def _gpu_mesh_discharge_rule(
out_avals,
*args,
jaxpr=jaxpr,
grid=tuple(mesh.shape.items()),
backend="mosaic_gpu",
mesh=mesh,
compiler_params=compiler_params,
debug=debug,
interpret=interpret,

View File

@ -450,6 +450,7 @@ def _block_spec_from_block_mapping(
def lower_pipelined_jaxpr_to_module(
grid_mapping: pallas_core.GridMapping,
mesh: pallas_core.Mesh | None,
jaxpr: jax_core.Jaxpr,
compiler_params: dict[str, Any],
cost_estimate: pallas_core.CostEstimate | None,
@ -473,7 +474,10 @@ def lower_pipelined_jaxpr_to_module(
block_mappings, [grid_mapping.num_inputs]
)
if grid_mapping.grid_names: # Last dim corresponds to the warpgroup count
if mesh is not None:
assert isinstance(mesh, gpu_core.GPUMesh)
if mesh and mesh.num_threads is not None:
# Last dim corresponds to the warpgroup count.
block = (128 * grid_mapping.grid[-1], 1, 1)
grid = grid_mapping.grid[:-1]
else:
@ -566,6 +570,7 @@ def lower_pipelined_jaxpr_to_module(
parallel_grid,
grid_mapping.grid_names,
block,
mesh.cluster if mesh is not None else (),
[bm.array_shape_dtype for bm in in_block_mappings],
[bm.array_shape_dtype for bm in out_block_mappings],
new_jaxpr,
@ -578,6 +583,7 @@ def lower_jaxpr_to_module(
grid: Sequence[int],
grid_names: Sequence[str],
block: Sequence[int],
cluster: Sequence[int],
in_shapes: Sequence[jax.ShapeDtypeStruct],
out_shapes: Sequence[jax.ShapeDtypeStruct],
jaxpr: jax_core.Jaxpr,
@ -640,7 +646,7 @@ def lower_jaxpr_to_module(
mgpu_core._lower_as_gpu_kernel(
body,
grid=parallel_grid,
cluster=(),
cluster=cluster,
block=block,
in_shapes=in_shapes,
out_shape=out_shapes,
@ -1559,9 +1565,10 @@ def _reduce_lowering_rule_wg(
if not out_aval.shape:
# Special-case: reducing to a scalar.
if x_aval.ndim != 1:
# TODO(slebedev): Flatten to 1D, since vector.reduction only supports
# 1D inputs.
raise NotImplementedError("Only 1D inputs are supported")
# Flatten to 1D, since vector.reduction only supports 1D inputs.
x = vector_dialect.shape_cast(
ir.VectorType.get([x_aval.size], out_type), x
)
return vector_dialect.ReductionOp(out_type, kind, x)
acc = vector_dialect.splat(
ir.VectorType.get(out_aval.shape, out_type),

View File

@ -38,6 +38,7 @@ def pallas_call_lowering(
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: pallas_core.GridMapping,
mesh: pallas_core.Mesh | None,
compiler_params: dict[str, Any],
cost_estimate: pallas_core.CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
@ -63,6 +64,7 @@ def pallas_call_lowering(
lowering_result = lowering.lower_pipelined_jaxpr_to_module(
grid_mapping,
mesh,
jaxpr,
compiler_params,
cost_estimate,

View File

@ -20,7 +20,7 @@ import dataclasses
import enum
from functools import partial, reduce
import types
from typing import Any, Literal
from typing import Any, Literal, cast
import jax
from jax import lax
@ -119,6 +119,7 @@ def _pallas_call_jvp_rule(
jaxpr: jax_core.Jaxpr,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
debug: bool,
interpret: bool,
compiler_params: Any,
@ -133,6 +134,8 @@ def _pallas_call_jvp_rule(
raise NotImplementedError
if input_output_aliases:
raise NotImplementedError("JVP with aliasing not supported.")
if mesh is not None:
raise NotImplementedError("pallas_call with a mesh does not support JVP")
nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
tangents = [t for t in tangents if type(t) is not ad_util.Zero]
nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs
@ -181,6 +184,7 @@ def _pallas_call_jvp_rule(
*tangents,
jaxpr=jvp_jaxpr,
grid_mapping=jvp_grid_mapping,
mesh=mesh,
interpret=interpret,
debug=debug,
input_output_aliases=(),
@ -317,6 +321,7 @@ def _batch_with_explicit_loop(
*,
jaxpr: jax_core.Jaxpr,
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
@ -384,6 +389,7 @@ def _batch_with_explicit_loop(
*batch_args,
jaxpr=jaxpr,
grid_mapping=grid_mapping,
mesh=mesh,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
@ -413,6 +419,7 @@ def _pallas_call_batching_rule(
*,
jaxpr: jax_core.Jaxpr,
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
@ -421,6 +428,11 @@ def _pallas_call_batching_rule(
out_avals: tuple[jax_core.AbstractValue, ...],
backend: _Backend | None,
):
if mesh is not None:
raise NotImplementedError(
"pallas_call with a mesh does not support batching"
)
def _maybe_squeeze_out_bdim(
x: jax.Array, bdim: int | batching.NotMapped
) -> jax.Array:
@ -445,6 +457,7 @@ def _pallas_call_batching_rule(
*args,
jaxpr=jaxpr,
grid_mapping=grid_mapping,
mesh=mesh,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
@ -478,6 +491,7 @@ def _pallas_call_batching_rule(
dims=dynamic_grid_dims + dims,
jaxpr=jaxpr,
grid_mapping=grid_mapping,
mesh=mesh,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
@ -512,6 +526,7 @@ def _pallas_call_batching_rule(
dims=scalar_bdims + bdims,
jaxpr=jaxpr,
grid_mapping=grid_mapping,
mesh=mesh,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
@ -890,6 +905,7 @@ def _pallas_call_batching_rule(
*args,
jaxpr=jaxpr,
grid_mapping=batched_grid_mapping,
mesh=mesh,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
@ -1339,12 +1355,13 @@ def _pallas_call_state_discharge_rule(
jaxpr: jax_core.Jaxpr,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
debug: bool,
interpret: bool,
compiler_params: Any,
cost_estimate: CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
backend: _Backend | None = None
backend: _Backend | None = None,
):
del avals_out
assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
@ -1440,6 +1457,7 @@ def _pallas_call_state_discharge_rule(
jaxpr=new_jaxpr,
input_output_aliases=new_input_output_aliases,
grid_mapping=new_grid_mapping,
mesh=mesh,
debug=debug,
interpret=interpret,
compiler_params=compiler_params,
@ -1526,16 +1544,6 @@ def pallas_call(
invoke the Pallas kernel.
"""
if compiler_params is None:
compiler_params = {}
if isinstance(compiler_params, pallas_core.CompilerParams):
if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
raise ValueError(
f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
compiler_params = {
compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
}
if grid_spec is None:
grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes)
else:
@ -1556,6 +1564,55 @@ def pallas_call(
"If `grid_spec` is specified, then `scratch_shapes` must "
f"be `()`. It is {scratch_shapes}")
del grid, in_specs, out_specs
return _pallas_call(
kernel,
out_shape,
grid_spec=grid_spec,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
name=name,
compiler_params=compiler_params,
cost_estimate=cost_estimate,
backend=backend,
)
def _pallas_call(
kernel: Callable[..., None],
out_shape: Any,
*,
grid_spec: GridSpec,
mesh: pallas_core.Mesh | None = None,
input_output_aliases: dict[int, int] = {},
debug: bool = False,
interpret: bool = False,
name: str | None = None,
compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
cost_estimate: CostEstimate | None = None,
backend: _Backend | None = None,
):
if compiler_params is None:
compiler_params = {}
if isinstance(compiler_params, pallas_core.CompilerParams):
if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
raise ValueError(
f"Unknown platform in compiler params: {compiler_params.PLATFORM}"
)
compiler_params = {
compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
}
if mesh is not None:
if tuple(mesh.shape.values()) != grid_spec.grid:
raise ValueError(
f"Mesh shape {tuple(mesh.shape.values())} does not match grid "
f"shape {grid_spec.grid}."
)
if backend is not None:
raise ValueError("If `mesh` is specified, then `backend` must be `None`.")
backend = cast(_Backend, mesh.backend)
grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
# TODO(necula): this canonicalization may be convenient for some usage
# but it is lossy, because it prevents expressing functions that return
@ -1643,6 +1700,7 @@ def pallas_call(
debug=debug,
interpret=interpret,
grid_mapping=grid_mapping,
mesh=mesh,
input_output_aliases=tuple(input_output_aliases.items()),
compiler_params=compiler_params,
cost_estimate=cost_estimate,

View File

@ -50,6 +50,7 @@ def pallas_call_lowering(
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: pallas_core.GridMapping,
mesh: pallas_core.Mesh | None,
compiler_params: dict[str, Any],
cost_estimate: pallas_core.CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
@ -64,6 +65,8 @@ def pallas_call_lowering(
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
if mesh is not None:
raise NotImplementedError("mesh is not supported in the Triton backend")
triton_params = compiler_params.get("triton", compiler_params)
num_warps = triton_params.get("num_warps", 4)
num_warps = 4 if num_warps is None else num_warps

View File

@ -670,8 +670,8 @@ def choice(key: ArrayLike,
ind = jnp.searchsorted(p_cuml, r).astype(int)
else:
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)
ind = jnp.argsort(g)[:n_draws]
g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr)
ind = lax.top_k(g, k=n_draws)[1].astype(int)
result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
return result.reshape(shape if arr.ndim == 0 else
@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array:
_uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
def categorical(key: ArrayLike,
logits: RealArray,
axis: int = -1,
shape: Shape | None = None) -> Array:
def categorical(
key: ArrayLike,
logits: RealArray,
axis: int = -1,
shape: Shape | None = None,
replace: bool = True,
) -> Array:
"""Sample random values from categorical distributions.
Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
the Gumbel top-k trick. See [1] for reference.
Args:
key: a PRNG key used as the random key.
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike,
shape: Optional, a tuple of nonnegative integers representing the result shape.
Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
replace: If True (default), perform sampling with replacement. If False,
perform sampling without replacement.
Returns:
A random array with int dtype and shape given by ``shape`` if ``shape``
is not None, or else ``np.delete(logits.shape, axis)``.
References:
.. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
Proceedings of the 36th International Conference on Machine Learning, PMLR
97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
"""
key, _ = _check_prng_key("categorical", key)
check_arraylike("categorical", logits)
logits_arr = jnp.asarray(logits)
if axis >= 0:
axis -= len(logits_arr.shape)
batch_shape = tuple(np.delete(logits_arr.shape, axis))
if shape is None:
shape = batch_shape
else:
shape = core.canonicalize_shape(shape)
_check_shape("categorical", shape, batch_shape)
shape_prefix = shape[:len(shape)-len(batch_shape)]
logits_shape = list(shape[len(shape) - len(batch_shape):])
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
return jnp.argmax(
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
axis=axis)
if replace:
if axis >= 0:
axis -= len(logits_arr.shape)
logits_shape = list(shape[len(shape) - len(batch_shape):])
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
return jnp.argmax(
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
axis=axis)
else:
logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
k = math.prod(shape_prefix)
if k > logits_arr.shape[axis]:
raise ValueError(
f"Number of samples without replacement ({k}) cannot exceed number of "
f"categories ({logits_arr.shape[axis]})."
)
_, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k)
assert indices.shape == batch_shape + (k,)
assert shape == shape_prefix + batch_shape
dimensions = (indices.ndim - 1, *range(indices.ndim - 1))
indices = lax.reshape(indices, shape, dimensions)
assert indices.shape == shape
return indices
def laplace(key: ArrayLike,
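
The Gumbel top-k trick used by both `choice` and the new `categorical(..., replace=False)` path can be sketched on its own (an illustration, assuming nothing beyond standard `jax.random` and `lax.top_k`):

```python
import jax
import jax.numpy as jnp
from jax import lax

key = jax.random.key(0)
logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))

# Perturb each logit with independent Gumbel noise, then take the top-k
# indices: the k winners are distinct and are distributed as k draws without
# replacement from the categorical distribution.
g = logits + jax.random.gumbel(key, logits.shape, logits.dtype)
_, idx = lax.top_k(g, k=2)
```
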

View File

@ -114,9 +114,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
return sdy_sharding
@util.cache(max_size=128, trace_context_in_key=False)
def get_replicated_hlo_sharding():
return xc.HloSharding.replicate()
replicated_hlo_sharding = xc.HloSharding.replicate()
@use_cpp_class(xc.SingleDeviceSharding)
@ -183,7 +181,7 @@ class SingleDeviceSharding(jsharding.Sharding):
return (self._device,)
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return get_replicated_hlo_sharding()
return replicated_hlo_sharding
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True)
@ -401,7 +399,7 @@ def _op_sharding_to_pos_sharding(
def _positional_sharding_to_xla_hlo_sharding(
self, num_dimensions: int) -> xc.HloSharding:
if self.shape == (1,) * self.ndim:
return get_replicated_hlo_sharding()
return replicated_hlo_sharding
pbuf = xc.OpSharding()
shape = self.shape[self.ndim - num_dimensions:] # 'rank promotion' of val
@ -603,7 +601,7 @@ class GSPMDSharding(jsharding.Sharding):
@functools.cached_property
def _hlo_sharding_hash(self):
if self.is_fully_replicated:
return hash(get_replicated_hlo_sharding())
return hash(replicated_hlo_sharding)
return hash(self._hlo_sharding)
def __eq__(self, other):
@ -669,7 +667,7 @@ class GSPMDSharding(jsharding.Sharding):
@classmethod
def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
return cls(tuple(device_assignment), replicated_hlo_sharding,
memory_kind=memory_kind)

View File

@ -244,52 +244,62 @@ def curry(f):
"""
return wraps(f)(partial(partial, f))
def toposort(end_nodes):
if not end_nodes: return []
end_nodes = _remove_duplicates(end_nodes)
# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum.
toposort: Callable[[Iterable[Any]], list[Any]]
if hasattr(jaxlib_utils, "topological_sort"):
toposort = partial(jaxlib_utils.topological_sort, "parents")
else:
child_counts = {}
stack = list(end_nodes)
while stack:
node = stack.pop()
if id(node) in child_counts:
child_counts[id(node)] += 1
else:
child_counts[id(node)] = 1
stack.extend(node.parents)
for node in end_nodes:
child_counts[id(node)] -= 1
def toposort(end_nodes):
if not end_nodes:
return []
end_nodes = _remove_duplicates(end_nodes)
sorted_nodes = []
childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
assert childless_nodes
while childless_nodes:
node = childless_nodes.pop()
sorted_nodes.append(node)
for parent in node.parents:
if child_counts[id(parent)] == 1:
childless_nodes.append(parent)
child_counts = {}
stack = list(end_nodes)
while stack:
node = stack.pop()
if id(node) in child_counts:
child_counts[id(node)] += 1
else:
child_counts[id(parent)] -= 1
sorted_nodes = sorted_nodes[::-1]
child_counts[id(node)] = 1
stack.extend(node.parents)
for node in end_nodes:
child_counts[id(node)] -= 1
check_toposort(sorted_nodes)
return sorted_nodes
sorted_nodes = []
childless_nodes = [
node for node in end_nodes if child_counts[id(node)] == 0
]
assert childless_nodes
while childless_nodes:
node = childless_nodes.pop()
sorted_nodes.append(node)
for parent in node.parents:
if child_counts[id(parent)] == 1:
childless_nodes.append(parent)
else:
child_counts[id(parent)] -= 1
sorted_nodes = sorted_nodes[::-1]
def check_toposort(nodes):
visited = set()
for node in nodes:
assert all(id(parent) in visited for parent in node.parents)
visited.add(id(node))
check_toposort(sorted_nodes)
return sorted_nodes
def check_toposort(nodes):
visited = set()
for node in nodes:
assert all(id(parent) in visited for parent in node.parents)
visited.add(id(node))
def _remove_duplicates(node_list):
seen = set()
out = []
for n in node_list:
if id(n) not in seen:
seen.add(id(n))
out.append(n)
return out
def _remove_duplicates(node_list):
seen = set()
out = []
for n in node_list:
if id(n) not in seen:
seen.add(id(n))
out.append(n)
return out
def split_merge(predicate, xs):
sides = list(map(predicate, xs))
@ -658,17 +668,12 @@ def use_cpp_class(cpp_cls: type[Any]) -> Callable[[type[T]], type[T]]:
exclude_methods = {'__module__', '__dict__', '__doc__'}
originals = {}
for attr_name, attr in cls.__dict__.items():
if attr_name not in exclude_methods:
if hasattr(_original_func(attr), "_use_cpp"):
originals[attr_name] = attr
else:
if not hasattr(_original_func(attr), "_use_cpp"):
setattr(cpp_cls, attr_name, attr)
cpp_cls.__doc__ = cls.__doc__
# TODO(pschuh): Remove once fastpath is gone.
cpp_cls._original_py_fns = originals
return cpp_cls
return wrapper

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for serialization and deserialization of GDA."""
import asyncio
import math

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mnist_lib, saved_model_lib, saved_model_main."""
import os
from absl import flags

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for call_tf."""
from collections.abc import Callable
import contextlib

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the jax2tf conversion for control-flow primitives."""
from absl.testing import absltest

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the shape-polymorphic jax2tf conversion."""
from __future__ import annotations

View File

@ -320,6 +320,20 @@ def _vector_splat_op_lowering_rule(
return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)]
@_register_lowering(vector.ShapeCastOp)
def _vector_shape_cast_op_lowering_rule(
_: LoweringContext, op: vector.ShapeCastOp
) -> Sequence[ir.Value]:
[layout] = inference_utils.in_layouts(op)
out_vec_ty = ir.VectorType(op.result.type)
assert out_vec_ty.has_static_shape
is_signed = (
False if ir.IntegerType.isinstance(out_vec_ty.element_type) else None
)
a = _fragmented_array_from_ir(op.source, layout, is_signed)
return [_fragmented_array_to_ir(a.reshape(out_vec_ty.shape), out_vec_ty)]
@_register_lowering(vector.ReductionOp)
def _vector_reduction_op_lowering_rule(
ctx: LoweringContext, op: vector.ReductionOp

View File

@ -382,21 +382,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]):
return WGMMA_LAYOUT
def _tiled_wgmma_layout_for_upcast(shape: tuple[int, ...]):
"""Returns a tiled layout that is easy to relayout to WGMMA layout after doubling the bitwidth."""
if len(shape) != 2:
raise ValueError(f"Shape {shape} is not 2D")
if shape[0] % 64 != 0 or shape[1] % 8 != 0:
raise ValueError(f"Shape {shape} is not a multiple of 64x8")
t = Tiling(((64, 16), (16, 16), (8, 16), (4,), (2, 1)))
return TiledLayout(
t,
warp_dim=-9,
lane_dims=(-5, -2, -4),
vector_dim=-3,
)
@dataclasses.dataclass(frozen=True)
class WGMMARowFragLayout:
"""[m] matrix, where m % 64 == 0."""
@ -505,13 +490,55 @@ WGMMA_ROW_LAYOUT = WGMMARowFragLayout()
# The tiled layout is equivalent to the one described in the PTX documentation:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d
# In this layout, we partition the 64x8 tiles over the 4 warps of a warpgroup
# into 16x8 tiles.
# Then, we further split the 16x8 tiles into 8x8 submatrices which are the unit
# of data that is split across a warp. Since 8*8 = 64, but a warp has only 32
# threads, we vectorize pairs of elements along columns.
# The assignment of elements to warp lanes is as follows:
#
# 0 0 1 1 2 2 3 3
# 4 4 5 5 6 6 7 7
# 8 8 9 9 10 10 11 11
# 12 12 13 13 14 14 15 15
# ...
WGMMA_LAYOUT = TiledLayout(
Tiling(((64, 8), (16, 8), (8, 8), (1, 2))),
warp_dim=-8,
lane_dims=(-4, -3),
vector_dim=-1,
)
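For reference, the lane assignment spelled out in the comment above can be written as a tiny Python sketch (the helper name is illustrative, not part of the codebase): each lane of a warp owns two column-adjacent elements of an 8x8 submatrix.

def wgmma_owner(row: int, col: int) -> tuple[int, int]:
  """Maps a (row, col) element of an 8x8 submatrix to (lane, slot in vector)."""
  assert 0 <= row < 8 and 0 <= col < 8
  return row * 4 + col // 2, col % 2

# The first two rows reproduce the "0 0 1 1 2 2 3 3" / "4 4 5 5 6 6 7 7" pattern.
assert [wgmma_owner(0, c)[0] for c in range(8)] == [0, 0, 1, 1, 2, 2, 3, 3]
assert [wgmma_owner(1, c)[0] for c in range(8)] == [4, 4, 5, 5, 6, 6, 7, 7]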
# This tiled layout is similar to the one above. Above, each warp stores a 8x8
# This tiled layout is similar to the WGMMA layout, only the unit at which we
# assign submatrices to warps grows from 8x8 to 8x16. The elements within each
# submatrix are assigned to threads in the following way:
#
# 0 0 0 0 2 2 2 2 1 1 1 1 3 3 3 3
# 4 4 4 4 6 6 6 6 5 5 5 5 7 7 7 7
# ...
#
# Our vector length is twice the size of that of WGMMA_LAYOUT, which lets us use
# 32-bit SMEM loads/stores when dealing with 8-bit values. The conversion
# to the WGMMA layout only requires communication between threads whose lane
# indices differ in their lowest bit (i.e. 0 and 1, 2 and 3), so the conversion to WGMMA_LAYOUT
# only requires a single warp shuffle (plus permutes local to each thread).
WGMMA_LAYOUT_UPCAST_2X = TiledLayout(
Tiling(((64, 16), (16, 16), (8, 16), (8,), (4,))),
warp_dim=-8,
lane_dims=(-4, -2, -3),
vector_dim=-1,
)
# This layout should be used when upcasting 4-bit elements to 16-bit, for the
# purpose of passing them into WGMMA later. The core matrices stored by a warp
# are 8x32, because each of the 4 threads in a row holds 8 elements in a single
# vector. Note that unlike WGMMA_LAYOUT_UPCAST_2X, we assign columns to each
# group of 4 threads in order (as opposed to the swapping between 1 and 2,
# 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does).
WGMMA_LAYOUT_UPCAST_4X = TiledLayout(
Tiling(((64, 32), (16, 32), (8, 32), (8,))),
warp_dim=-7,
lane_dims=(-3, -2),
vector_dim=-1,
)
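The same kind of sketch for the two upcast layouts, derived from the comments above (helper names again illustrative): WGMMA_LAYOUT_UPCAST_2X assigns groups of 4 columns with groups 1 and 2 swapped, while WGMMA_LAYOUT_UPCAST_4X assigns groups of 8 columns to lanes in order.

def upcast_2x_owner(row: int, col: int) -> tuple[int, int]:
  """(lane, slot) for an element of an 8x16 submatrix in WGMMA_LAYOUT_UPCAST_2X."""
  assert 0 <= row < 8 and 0 <= col < 16
  return row * 4 + (0, 2, 1, 3)[col // 4], col % 4

def upcast_4x_owner(row: int, col: int) -> tuple[int, int]:
  """(lane, slot) for an element of an 8x32 submatrix in WGMMA_LAYOUT_UPCAST_4X."""
  assert 0 <= row < 8 and 0 <= col < 32
  return row * 4 + col // 8, col % 8

# First row of UPCAST_2X matches "0 0 0 0 2 2 2 2 1 1 1 1 3 3 3 3" above.
assert [upcast_2x_owner(0, c)[0] for c in range(16)] == [0] * 4 + [2] * 4 + [1] * 4 + [3] * 4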
# This tiled layout is similar to WGMMA_LAYOUT. There, each warp stores an 8x8
# submatrix in the following way (we only show the first 4 rows for brevity):
#
# 0 0 1 1 2 2 3 3
@ -697,6 +724,7 @@ class FragmentedArray:
At the moment, only conversions from ``WGSplatFragLayout`` are supported.
"""
i32 = ir.IntegerType.get_signless(32)
c = lambda x: arith.constant(i32, x)
if self.layout == new_layout:
return self
shape = self.shape
@ -707,24 +735,148 @@ class FragmentedArray:
):
is_even_row = arith.cmpi(
arith.CmpIPredicate.eq,
arith.remui(arith.divui(utils.thread_idx(), c(4, i32)), c(2, i32)),
c(0, i32),
arith.remui(arith.divui(utils.thread_idx(), c(4)), c(2)),
c(0),
)
perm = arith.select(is_even_row, c(0x5410, i32), c(0x3276, i32))
perm = arith.select(is_even_row, c(0x5410), c(0x3276))
new_regs = []
for reg in self.registers.flat:
reg_ty = reg.type
reg = utils.bitcast(reg, i32)
reg_shfl = utils.shfl_bfly(reg, 4)
new_reg = llvm.inline_asm(
i32, [reg, reg_shfl, perm], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r"
)
new_reg = utils.prmt(reg, reg_shfl, perm)
new_regs.append(utils.bitcast(new_reg, reg_ty))
return FragmentedArray(
_registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)),
_layout=new_layout,
_is_signed=self.is_signed,
)
if (
self.layout == WGMMA_LAYOUT_UPCAST_2X
and new_layout == WGMMA_LAYOUT
and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16
):
assert shape[1] % 16 == 0 # Should be implied by the layout
new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
is_even = arith.cmpi(
arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
)
registers = self.registers
if dtype_bitwidth == 4:
if registers.shape[1] % 2:
raise NotImplementedError(
"This relayout implementation requires an even number of column"
" tiles (to pack pairs of them for efficiency)"
)
# We pair up the consecutive column tiles, so each register is 32-bit.
# If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout,
# LLVM will realize that the paired up vectors actually came from the
# same 32-bit register and it will become a no-op.
col_minor_registers = np.moveaxis(registers, 1, -1)
flat_registers = [
utils.vector_concat((l, h))
for l, h in zip(
col_minor_registers.flat[::2], col_minor_registers.flat[1::2]
)
]
registers = np.asarray(flat_registers, dtype=object).reshape(
*col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2
)
registers = np.moveaxis(registers, -1, 1)
for idx, reg in np.ndenumerate(registers):
if dtype_bitwidth == 16:
assert reg.type.shape == [4]
# A single vector is 64-bits, but shuffles are only 32-bit wide.
# We only shuffle the half that needs to go to other thread.
low = utils.vector_slice(reg, slice(0, 2))
high = utils.vector_slice(reg, slice(2, 4))
to_exchange = arith.select(is_even, high, low)
# Exchange values between even and odd threads.
exchanged = utils.shfl_bfly(to_exchange, 1)
low = arith.select(is_even, low, exchanged)
high = arith.select(is_even, exchanged, high)
new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low
new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high
elif dtype_bitwidth == 8:
assert reg.type.shape == [4]
# The vector is 32-bits, so we just shuffle the whole thing and
# use prmt to blend it with the local register.
exchanged = utils.shfl_bfly(reg, 1)
# Consider lanes 0 and 1, because the situation is symmetric for
# each pair. If we feed reg[lane] and exchanged[lane] (which is
# really the same as reg of the other lane) to prmt, we can index
# the elements of the result using the following indices:
# reg[0]: 0 1 2 3 reg[1]: 8 9 10 11
# prmt[0]: 0 1 2 3 4 5 6 7
# prmt[1]: 4 5 6 7 0 1 2 3
# The expected outputs and their respective permutations are:
# out[0]: 0 1 8 9 out[1]: 2 3 10 11
# prmt[0]: 0 1 4 5 prmt[1]: 6 7 2 3
# Note that the patterns still need to be flipped, since we listed
# bytes with LSB on the left, which is the opposite of how the
# numeric constants are spelled in Python (LSB on the right).
perm = arith.select(is_even, c(0x5410), c(0x3276))
blend = utils.prmt(reg, exchanged, perm)
for i in range(2):
reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
else:
assert dtype_bitwidth == 4
assert reg.type.shape == [8] # We paired up the registers above.
exchanged = utils.shfl_bfly(reg, 1)
# See comment above for a more complete explanation.
# reg[0]: 0 1 2 3 16 17 18 19 reg[1]: 8 9 10 11 24 25 26 27
# prmt[0]: -0- -1- --2-- --3-- -4- --5-- --6-- --7--
# prmt[1]: -4- -5- --6-- --7-- -0- --1-- --2-- --3--
# The expected outputs and their respective permutations are:
# out[0]: 0 1 8 9 16 17 24 25 out[1]: 2 3 10 11 18 19 26 27
# prmt[0]: -0- -4- --2-- --6-- prmt[1]: -5- --1-- --7-- --3--
perm = arith.select(is_even, c(0x6240), c(0x3715))
blend = utils.prmt(reg, exchanged, perm)
for i in range(4):
reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg
assert all(r is not None for r in new_registers)
return FragmentedArray(
_registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
)
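The 16-bit branch above can be checked with a plain-Python model of one even/odd lane pair, using the same element numbering as the comments (Python lists stand in for the 4-wide registers; this is a sketch, not the actual lowering).

def relayout_16bit_pair(reg_even, reg_odd):
  low_e, high_e = reg_even[:2], reg_even[2:]
  low_o, high_o = reg_odd[:2], reg_odd[2:]
  # Each lane shuffles away only the half that belongs to the other lane.
  new_even = low_e + low_o    # keeps its low half, receives the odd lane's low half
  new_odd = high_e + high_o   # receives the even lane's high half, keeps its high half
  return new_even, new_odd

# Even lane holds elements (0, 1, 2, 3), odd lane holds (8, 9, 10, 11); after the
# exchange they hold (0, 1, 8, 9) and (2, 3, 10, 11), as required by WGMMA_LAYOUT.
assert relayout_16bit_pair([0, 1, 2, 3], [8, 9, 10, 11]) == ([0, 1, 8, 9], [2, 3, 10, 11])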
if (
self.layout == WGMMA_LAYOUT_UPCAST_4X
and new_layout == WGMMA_LAYOUT_UPCAST_2X
and utils.bitwidth(self.mlir_dtype) == 4
):
assert shape[0] % 64 == 0 # Should be implied by the layout
assert shape[1] % 32 == 0 # Should be implied by the layout
new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
i32 = ir.IntegerType.get_signless(32)
c = lambda x: arith.constant(i32, x)
is_01 = arith.cmpi(
arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2)
)
for idx, reg in np.ndenumerate(self.registers):
assert ir.VectorType(reg.type).shape == [8]
# The vector is 32-bits, so we just shuffle the whole thing and
# use prmt to blend it with the local register.
exchanged = utils.shfl_bfly(reg, 2)
# See comments above for conventions. Here we exchange data between
# threads with lane index related by flipping 2nd bit (e.g. 0 and 2).
# reg[0]: 0 1 2 3 4 5 6 7 reg[2]: 16 17 18 19 20 21 22 23
# prmt[0]: -0- -1- -2- -3- --4-- --5-- --6-- --7--
# prmt[1]: -4- -5- -6- -7- --0-- --1-- --2-- --3--
# The expected outputs and their respective permutations are:
# out[0]: 0 1 2 3 16 17 18 19 out[2]: 4 5 6 7 20 21 22 23
# prmt[0]: -0- -1- --4-- --5-- prmt[2]: -6- -7- --2-- --3--
perm = arith.select(is_01, c(0x5410), c(0x3276))
blend = utils.prmt(reg, exchanged, perm)
for i in range(2):
reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4))
new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
assert all(r is not None for r in new_registers)
return FragmentedArray(
_registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
)
if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT:
return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout)
if not isinstance(self.layout, WGSplatFragLayout):
raise NotImplementedError(
f"Cannot convert from {self.layout} to {new_layout}"
@ -1178,11 +1330,15 @@ class FragmentedArray:
is_vector_reg = ir.VectorType.isinstance(reg_type)
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
[vector_len] = reg_shape # This is meant to be a 1D assertion.
if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len == 2:
if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8:
raise ValueError(
"Register bitwidth in target type must be divisible by 8, got"
f" {new_reg_bitwidth}"
)
if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
new_registers = np.empty_like(self.registers)
empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32))
out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
for idx, reg in np.ndenumerate(self.registers):
reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
# The algorithm here is largely the same as CUTLASS's
# NumericArrayConverter specialization for int4 -> bf16 casts.
# We modify it slightly, because we only extract 2 values.
@ -1196,25 +1352,58 @@ class FragmentedArray:
# positive int4s will end up larger than negative int4s, with a bias of
# 8. We use the sub to subtract the base (our initial exponent) and the
# bias coming from flipping the sign bit which is 136 (0x4308 as bits).
new_reg_32 = llvm.inline_asm(
i32,
[reg_8],
"""
{
.reg .b32 s<4>;
shr.s32 s0, $1, 4;
prmt.b32 s1, $1, s0, 0xF4F0;
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
mov.b32 s3, 0x43084308;
sub.bf16x2 $0, s2, s3;
}
""",
"=r,r",
)
new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
new_registers[idx] = vector.bitcast(
ir.VectorType.get((vector_len,), new_dtype), new_vec_32
)
def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
assert 0 <= part < 4
return llvm.inline_asm(
i32,
[reg, reg_shr],
f"""
{{
.reg .b32 s<4>;
prmt.b32 s1, $1, $2, 0xF{part + 4}F{part};
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
mov.b32 s3, 0x43084308;
sub.bf16x2 $0, s2, s3;
}}
""",
"=r,r,r",
)
offset = 0
out_int_regs = []
for group_size in (8, 4, 2):
int_ty = ir.IntegerType.get_signless(group_size * 4)
while vector_len - offset >= group_size:
# If the vector originates from a slice (common after relayouts), we
# can fuse the slicing into the conversion and prevent LLVM from
# generating a bunch of shifts to align the vector data to the LSB.
# This also lets us share the right shift among more vectors.
if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
and utils.bitwidth(slice_op.vector.type) == 32
and slice_op.strides[0].value == 1):
slice_offset = slice_op.offsets[0].value + offset
reg_int = utils.bitcast(slice_op.vector, i32)
reg_int_shr = arith.shrui(reg_int, c(4, i32))
out_int_regs.extend(
upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
for part in range(group_size // 2)
)
else:
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
reg_slice_int = utils.bitcast(reg_slice, int_ty)
if int_ty != i32:
reg_slice_int = arith.extsi(i32, reg_slice_int)
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
out_int_regs.extend(
upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
for part in range(group_size // 2)
)
offset += group_size
assert offset == vector_len
out_vec_int = utils.vector_concat([
vector.splat(ir.VectorType.get((1,), i32), reg)
for reg in out_int_regs
])
new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
return FragmentedArray(
_registers=new_registers, _layout=self.layout, _is_signed=None
)
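The lop3/sub trick above can be sanity-checked with a pure-Python model of the bfloat16 bit layout. The lop3 immediate (0xf0 & 0xcc) ^ 0xaa computes (a & b) ^ c, so each 16-bit half becomes (nibble & 0xF) ^ 0x4308, i.e. the bf16 encoding of 136.0 plus the signed int4 value, and the final sub removes the 136.0 bias. A sketch for verification only (it only handles normal bf16 values, which is all the trick produces):

def bf16_bits_to_float(bits: int) -> float:
  # Normal bfloat16 values only; the exponent field here is always 0x86.
  sign = -1.0 if bits & 0x8000 else 1.0
  exponent = (bits >> 7) & 0xFF
  mantissa = bits & 0x7F
  return sign * (1.0 + mantissa / 128.0) * 2.0 ** (exponent - 127)

def upcast_int4_nibble(nibble: int) -> float:
  biased_bits = (nibble & 0xF) ^ 0x4308           # lop3: (s1 & 0x000F) ^ 0x4308
  return bf16_bits_to_float(biased_bits) - 136.0  # sub.bf16x2 with 0x4308

for value in range(-8, 8):
  assert upcast_int4_nibble(value & 0xF) == float(value)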
@ -1263,11 +1452,6 @@ class FragmentedArray:
_registers=new_registers, _layout=self.layout, _is_signed=is_signed
)
# Generic path.
# XLA packs elements into bytes in big-endian order, while LLVM assumes the
# same endianness as the target machine (which is little for NVIDIA GPUs).
# We'll need to add specialized casting routines that flip the endianness.
if 1 < utils.bitwidth(cur_dtype) < 8 or 1 < utils.bitwidth(new_dtype) < 8:
raise NotImplementedError("Conversion involving sub-byte types unsupported")
from_float = ir.FloatType.isinstance(cur_dtype)
to_float = ir.FloatType.isinstance(new_dtype)
from_integer = ir.IntegerType.isinstance(cur_dtype)
@ -1472,17 +1656,17 @@ class FragmentedArray:
def reshape(self, shape):
if self.shape == shape:
return self
if not isinstance(self.layout, WGSplatFragLayout):
raise NotImplementedError(self.layout)
if np.prod(shape) != np.prod(self.shape):
if math.prod(shape) != math.prod(self.shape):
raise ValueError(f"Can't reshape {self.shape} to {shape}")
match self.layout:
case WGSplatFragLayout() | WGStridedFragLayout():
new_layout = dataclasses.replace(self.layout, shape=shape)
case _:
raise NotImplementedError(self.layout)
return FragmentedArray(
_registers=self.registers,
_layout=WGSplatFragLayout(shape),
_is_signed=self.is_signed,
_registers=self.registers, _layout=new_layout, _is_signed=self.is_signed
)
def broadcast_minor(self, n):

View File

@ -336,6 +336,37 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:
return [], [layout]
def _update_layout_shape(
layout: ir.Attribute, shape: Sequence[int], origin: str
) -> ir.Attribute:
if layouts_lib.is_splat_fragmented_layout(
layout
) or layouts_lib.is_strided_fragmented_layout(layout):
return layouts_lib.to_layout_attr(
dataclasses.replace(layouts_lib.from_layout_attr(layout), shape=shape)
)
raise NotImplementedError(f"Unsupported {origin} layout: {layout}.")
@partial(_add_layout_inference_rule, vector.ShapeCastOp)
def _infer_shape_cast_op_layout(op: vector.ShapeCastOp) -> OptionalLayouts:
in_layout = inference_utils.value_layout(op.source)
if in_layout is None:
out_layout = inference_utils.value_layout(op.result)
if out_layout is None:
return None
in_layout = _update_layout_shape(
out_layout, ir.VectorType(op.source.type).shape, "source"
)
return [in_layout], [out_layout]
out_layout = _update_layout_shape(
in_layout, ir.VectorType(op.result.type).shape, "result"
)
return [in_layout], [out_layout]
@partial(_add_layout_inference_rule, vector.ReductionOp)
def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts:
if layout := inference_utils.value_layout(op.vector):

View File

@ -83,6 +83,8 @@ def mma(
accumulate: ir.Value | bool = True,
collective: bool = False,
):
if a_swizzle == 16 or b_swizzle == 16:
raise NotImplementedError("No swizzle is not supported")
i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
if isinstance(accumulate, bool):

View File

@ -25,8 +25,12 @@ from typing import cast
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import vector
from . import fragmented_array as fa
from . import inference_utils
from . import layouts as layouts_lib
from . import utils
# mypy: ignore-errors
@ -39,7 +43,9 @@ _transform_inference_rules: dict[str, TransformInferenceRule] = {}
def _add_transform_inference_rule(
op: type[ir.OpView], rule: TransformInferenceRule
):
_transform_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error
if op is not None:
_transform_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error
return rule
def _set_transform_attributes(
@ -110,6 +116,86 @@ def _infer_async_load_transforms(op: mgpu.AsyncLoadOp) -> OptionalTransforms:
return None if in_transforms is None else ([in_transforms], [])
@partial(_add_transform_inference_rule, vector.LoadOp)
@partial(_add_transform_inference_rule, vector.StoreOp)
def _infer_vector_load_store_transforms(
op: vector.LoadOp | vector.StoreOp,
) -> OptionalTransforms:
for i in op.indices:
index_defining_op = i.owner.opview
if (
not isinstance(index_defining_op, arith.ConstantOp)
or index_defining_op.literal_value != 0
):
# TODO(bchetioui): handle slicing.
raise NotImplementedError(
f"Only constants with value 0 are supported as indices for {op}"
)
if isinstance(op, vector.LoadOp):
[layout_attr] = inference_utils.out_layouts(op)
else:
assert isinstance(op, vector.StoreOp)
[layout_attr] = inference_utils.in_layouts(op)
layout = layouts_lib.from_layout_attr(layout_attr)
transforms = inference_utils.value_transforms(op.base)
if layout == fa.WGMMA_LAYOUT:
layout_transforms = infer_transforms_for_wgmma_ref(
ir.MemRefType(op.base.type)
)
elif (isinstance(layout, fa.WGStridedFragLayout) or
isinstance(layout, fa.WGSplatFragLayout)):
layout_transforms = None
else:
raise NotImplementedError(
f"Got layout {layout} which is not yet supported"
)
if transforms is not None and layout_transforms is not None:
if transforms != layout_transforms:
raise NotImplementedError(
f"Conflicting transforms for {op.base} in {op}: "
f"{transforms} != {layout_transforms}."
)
return [transforms], []
if transforms is not None:
return [transforms], []
if layout_transforms is not None:
return [layout_transforms], []
return None
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
@partial(_add_transform_inference_rule, SliceSMEMOp)
def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
transforms = None
uses = cast(ir.OpResult, op.result).uses
for op_operand_use in uses:
consumer = op_operand_use.owner
op_user = consumer.operands[op_operand_use.operand_number]
out_transforms = inference_utils.in_transforms_for_operand(
consumer, op_user
)
if transforms is not None and out_transforms is not None:
if transforms != out_transforms:
raise NotImplementedError(
f"Conflicting transforms for {op_user} in {op}: "
f"{transforms} != {out_transforms}."
)
elif out_transforms is not None:
transforms = out_transforms
return None if transforms is None else ([], [transforms])
def _should_have_transforms(op: ir.OpView) -> bool:
"""Returns 'True' if the operation should be assigned in/out transforms."""
return any(

View File

@ -346,8 +346,11 @@ def bitwidth_impl(ty: ir.Type):
return ir.IntegerType(ty).width
if ir.FloatType.isinstance(ty):
return ir.FloatType(ty).width
if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
return MBARRIER_BYTES * 8
if ir.VectorType.isinstance(ty):
vty = ir.VectorType(ty)
return math.prod(vty.shape) * bitwidth(vty.element_type)
raise NotImplementedError(ty)
@ -1180,13 +1183,33 @@ def shfl_bfly(x: ir.Value, distance: int | ir.Value):
i32 = ir.IntegerType.get_signless(32)
if isinstance(distance, int):
distance = c(distance, i32)
assert x.type == i32
return nvvm.shfl_sync(
if (result_type := x.type) != i32:
x = bitcast(x, i32)
y = nvvm.shfl_sync(
i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly,
)
return bitcast(y, result_type)
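A butterfly shuffle exchanges data between lanes whose indices differ by the given distance (a lane-index XOR); the relayouts above rely on this with distances 1, 2 and 4. A minimal pure-Python model, assuming a full 32-lane warp with no divergence:

def shfl_bfly_model(values, distance):
  """Lane i receives the value held by lane i ^ distance."""
  assert len(values) == 32
  return [values[i ^ distance] for i in range(32)]

lanes = list(range(32))
assert shfl_bfly_model(lanes, 1)[:4] == [1, 0, 3, 2]  # pairs 0<->1, 2<->3
assert shfl_bfly_model(lanes, 2)[:4] == [2, 3, 0, 1]  # pairs 0<->2, 1<->3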
def prmt(high: ir.Value, low: ir.Value, permutation: ir.Value):
i32 = ir.IntegerType.get_signless(32)
if (result_type := high.type) != low.type:
raise ValueError(f"Types must match, got {high.type} and {low.type}")
if high.type != i32:
high = bitcast(high, i32)
if low.type != i32:
low = bitcast(low, i32)
if permutation.type != i32:
permutation = bitcast(permutation, i32)
result = llvm.inline_asm(
i32, [high, low, permutation], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r"
)
return bitcast(result, result_type)
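The byte-selector semantics of prmt, and the 0x5410/0x3276 constants used in the relayouts above, can be modelled in plain Python. This sketch covers only the default prmt mode and ignores the sign-replication bit, which the code above never sets:

def prmt_model(a: int, b: int, selector: int) -> int:
  """Model of `prmt.b32 d, a, b, s`: pool bytes 0-3 come from a, 4-7 from b;
  each selector nibble (least significant first) picks one result byte."""
  pool = [(a >> (8 * i)) & 0xFF for i in range(4)]
  pool += [(b >> (8 * i)) & 0xFF for i in range(4)]
  out = 0
  for i in range(4):
    out |= pool[(selector >> (4 * i)) & 0x7] << (8 * i)
  return out

def pack(byte_values):  # little-endian packing helper
  return sum((v & 0xFF) << (8 * i) for i, v in enumerate(byte_values))

def unpack(word):
  return [(word >> (8 * i)) & 0xFF for i in range(4)]

# Reproduce the 8-bit relayout above: even lanes hold elements (0, 1, 2, 3), odd
# lanes hold (8, 9, 10, 11), and after a distance-1 butterfly shuffle each lane
# blends its own register with the exchanged one.
reg_even, reg_odd = pack([0, 1, 2, 3]), pack([8, 9, 10, 11])
assert unpack(prmt_model(reg_even, reg_odd, 0x5410)) == [0, 1, 8, 9]
assert unpack(prmt_model(reg_odd, reg_even, 0x3276)) == [2, 3, 10, 11]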
def bitcast(x: ir.Value, new_type: ir.Type):
if x.type == new_type:
return x
if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type):
new_type = ir.IntegerType(new_type)
x_ty = ir.VectorType(x.type)
@ -1200,8 +1223,50 @@ def bitcast(x: ir.Value, new_type: ir.Type):
x_ty = ir.IntegerType(x.type)
assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
x_ty = ir.VectorType(x.type)
new_ty = ir.VectorType(new_type)
if bitwidth(x_ty) != bitwidth(new_ty):
raise ValueError(f"Can't bitcast {x.type} to {new_type}")
return vector.bitcast(new_type, x)
raise ValueError(f"Can't bitcast {x.type} to {new_type}")
def ceil_div(x: int, y: int):
return (x + y - 1) // y
def vector_slice(v: ir.Value, s: slice):
v_ty = ir.VectorType(v.type)
if len(v_ty.shape) != 1:
raise NotImplementedError(v_ty)
[v_len] = v_ty.shape
slice_length = len(range(v_len)[s])
return vector.extract_strided_slice(
ir.VectorType.get((slice_length,), v_ty.element_type),
v, [s.start or 0], [slice_length], [1],
)
def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
index = ir.IndexType.get()
if not vectors:
raise ValueError("Cannot concatenate an empty list of vectors")
vty = vectors[0].type
if not ir.VectorType.isinstance(vty):
raise ValueError("Cannot concatenate non-vector values")
if vty.rank != 1:
raise NotImplementedError("Only 1D vectors are supported")
for v in vectors:
if v.type != vty:
raise ValueError("Cannot concatenate vectors of different types")
result = llvm.mlir_undef(
ir.VectorType.get((vty.shape[0] * len(vectors),), vty.element_type)
)
offset = 0
for v in vectors:
for i in range(vty.shape[0]):
elem = vector.extractelement(v, position=c(i, index))
result = vector.insertelement(elem, result, position=c(offset + i, index))
offset += vty.shape[0]
return result

View File

@ -259,6 +259,8 @@ def wgmma(
The refs must be contiguous or be contiguous except for having their two minor
dimensions swapped.
"""
if swizzle == 16:
raise NotImplementedError("No swizzle is not supported")
# Step 1. Establish the shape and element type of the operation.
if not ir.MemRefType.isinstance(b.type):
raise ValueError(f"B must be a memref, got: {b.type}")

View File

@ -214,6 +214,8 @@ nanobind_extension(
module_name = "utils",
deps = [
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/synchronization",
"@nanobind",

View File

@ -65,6 +65,7 @@ cc_library(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:VectorDialect",
],
)

View File

@ -238,11 +238,12 @@ NB_MODULE(_mosaic_gpu_ext, m) {
"failed to enable tracking of kernel activity by CUPTI");
});
m.def("_cupti_get_timings", []() {
THROW_IF_CUPTI_ERROR(
cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED),
"failed to flush CUPTI activity buffers");
THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber),
"failed to unsubscribe from CUPTI");
THROW_IF_CUPTI_ERROR(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_NONE),
"failed to flush CUPTI activity buffers");
THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
return profiler_state.timings;
});
}

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "jaxlib/mosaic/gpu/passes.h"
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
@ -23,6 +24,7 @@ limitations under the License.
#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/SymbolTable.h"
@ -36,6 +38,49 @@ namespace gpu {
namespace {
// Upstream MLIR does not implement an LLVM lowering pattern for this op.
struct ConvertExtractStridedSlicePattern final
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
mlir::LogicalResult matchAndRewrite(
mlir::vector::ExtractStridedSliceOp op, OpAdaptor subst,
mlir::ConversionPatternRewriter &rewriter) const override {
auto vty = op.getSourceVectorType();
if (vty.getRank() != 1) {
return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported");
}
int64_t size =
(*op.getSizes().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
if (size < 0) {
return rewriter.notifyMatchFailure(op, "size is negative");
}
int64_t start =
(*op.getOffsets().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
int64_t stride =
(*op.getStrides().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
if (stride != 1) {
return rewriter.notifyMatchFailure(op, "only stride 1 is supported");
}
if (start < 0 || start + size > vty.getShape()[0]) {
return rewriter.notifyMatchFailure(op, "slice is out of bounds");
}
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
op.getLoc(), op.getResult().getType());
for (int64_t i = 0; i < size; ++i) {
result = rewriter.create<mlir::LLVM::InsertElementOp>(
op.getLoc(), result,
rewriter.create<mlir::LLVM::ExtractElementOp>(
op.getLoc(), subst.getVector(),
rewriter.create<mlir::LLVM::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(i + start))),
rewriter.create<mlir::LLVM::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(i)));
}
rewriter.replaceOp(op, result);
return mlir::success();
}
};
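In plain terms, for the 1-D, unit-stride case the pattern supports, the op copies `size` consecutive elements starting at `start`. A reference sketch in Python (the function name is illustrative only):

def extract_strided_slice_model(source, start, size):
  assert 0 <= start and start + size <= len(source)
  return [source[start + i] for i in range(size)]

assert extract_strided_slice_model([10, 20, 30, 40, 50], 1, 3) == [20, 30, 40]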
class ConvertGpuToLLVMPass
: public jaxlib::mlir::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp> {
public:
@ -58,6 +103,7 @@ class ConvertGpuToLLVMPass
});
auto symtab = mlir::SymbolTable(getOperation());
mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false);
patterns.insert<ConvertExtractStridedSlicePattern>(&getContext());
if (mlir::applyPartialConversion(getOperation(), target,
std::move(patterns))
.failed()) {

View File

@ -16,9 +16,13 @@ limitations under the License.
#include <Python.h>
#include <cstddef>
#include <utility>
#include <vector>
#include "nanobind/nanobind.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/synchronization/mutex.h"
@ -293,6 +297,69 @@ PyMethodDef safe_zip_def = {
METH_FASTCALL,
};
nb::list TopologicalSort(nb::str parents_attr,
nb::iterable end_nodes_iterable) {
// This is a direct conversion of the original Python implementation.
// More efficient implementations of a topological sort are possible (and
// indeed, easier to write), but changing the choice of topological order
// would break existing tests.
std::vector<nb::object> end_nodes;
absl::flat_hash_set<PyObject*> seen;
for (nb::handle n : end_nodes_iterable) {
nb::object node = nb::borrow(n);
if (seen.insert(node.ptr()).second) {
end_nodes.push_back(node);
}
}
nb::list sorted_nodes;
if (end_nodes.empty()) {
return sorted_nodes;
}
std::vector<nb::object> stack = end_nodes;
absl::flat_hash_map<PyObject*, int> child_counts;
while (!stack.empty()) {
nb::object node = std::move(stack.back());
stack.pop_back();
auto& count = child_counts[node.ptr()];
if (count == 0) {
for (nb::handle parent : node.attr(parents_attr)) {
stack.push_back(nb::borrow(parent));
}
}
++count;
}
for (nb::handle n : end_nodes) {
child_counts[n.ptr()] -= 1;
}
std::vector<nb::object> childless_nodes;
childless_nodes.reserve(end_nodes.size());
for (nb::handle n : end_nodes) {
if (child_counts[n.ptr()] == 0) {
childless_nodes.push_back(nb::borrow(n));
}
}
while (!childless_nodes.empty()) {
nb::object node = std::move(childless_nodes.back());
childless_nodes.pop_back();
sorted_nodes.append(node);
for (nb::handle parent : node.attr(parents_attr)) {
auto& count = child_counts[parent.ptr()];
if (count == 1) {
childless_nodes.push_back(nb::borrow(parent));
} else {
--count;
}
}
}
sorted_nodes.reverse();
return sorted_nodes;
}
} // namespace
NB_MODULE(utils, m) {
@ -304,6 +371,13 @@ NB_MODULE(utils, m) {
m.attr("safe_zip") = nb::steal<nb::object>(
PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));
m.def("topological_sort", &TopologicalSort, nb::arg("parents_attr"),
nb::arg("end_nodes"),
"Computes a topological sort of a graph of objects. parents_attr is "
"the name of the attribute on each object that contains the list of "
"parent objects. end_nodes is an iterable of objects from which we "
"should start a backwards search.");
// Python has no reader-writer lock in its standard library, so we expose
// bindings around absl::Mutex.
nb::class_<absl::Mutex>(m, "Mutex")
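A minimal sketch of how the new topological_sort binding can be driven from Python; the Node class here is made up for illustration and only needs a `parents` attribute, and the import path assumes the extension is exposed as jaxlib.utils:

from jaxlib import utils  # assumed import path for the extension above

class Node:
  def __init__(self, name, parents=()):
    self.name = name
    self.parents = list(parents)

a = Node("a")
b = Node("b", parents=[a])
c = Node("c", parents=[a, b])

# Parents are returned before their children, matching the Python toposort.
order = utils.topological_sort("parents", [c])
assert [n.name for n in order] == ["a", "b", "c"]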

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for AOT compilation."""
import contextlib
import unittest

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax.api_util."""
import itertools as it
from absl.testing import absltest

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Array."""
import contextlib
import math

View File

@ -1356,6 +1356,32 @@ class VmappableTest(jtu.JaxTestCase):
self.assertEqual(ans.names, expected.names)
self.assertAllClose(ans.data, expected.data)
def test_types_with_same_spec(self):
# We register NamedArray.
batching.register_vmappable(NamedArray, NamedMapSpec, int,
named_to_elt, named_from_elt, None)
# We then register another type that uses NamedMapSpec as the spec_type too,
# and immediately unregister it.
class Foo:
pass
batching.register_vmappable(Foo, NamedMapSpec, int,
named_to_elt, named_from_elt, None)
batching.unregister_vmappable(Foo)
# We should still be able to use vmap on NamedArray.
def f(x):
return named_mul(x, x)
x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
ans = jax.jit(f)(x)
expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2)
self.assertEqual(ans.names, expected.names)
self.assertAllClose(ans.data, expected.data)
# And unregister NamedArray without exceptions.
batching.unregister_vmappable(NamedArray)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -37,18 +37,41 @@ def call_kernel(
m, n = grid
return jnp.concatenate([
jnp.concatenate([
kernel(i, j, *args) for j in range(n)], axis=1)
kernel((i, j), *args) for j in range(n)], axis=1)
for i in range(m)], axis=0)
def uniform_kernel(i: int, j: int, total_size, block_size, tile_size):
"""Uniform random sampling kernel function."""
global_key = jax.random.key(0)
keys = blocked_sampler.blocked_fold_in(global_key,
def call_kernel_3d(
kernel,
grid: tuple[int, int, int],
*args
):
"""Calls a kernel over a 3D grid and concatenates results to a single array."""
depth, rows, cols = grid
return jnp.concatenate([
jnp.concatenate([
jnp.concatenate([
jnp.array(kernel((i, j, k), *args))
for k in range(cols)], axis=2)
for j in range(rows)], axis=1)
for i in range(depth)], axis=0)
def blocked_fold_in(block_index, key, total_size, block_size, tile_size):
"""Folds in block_index into global_key."""
return blocked_sampler.blocked_fold_in(key,
total_size=total_size,
block_size=block_size,
tile_size=tile_size,
block_index=(i, j))
block_index=block_index)
def uniform_kernel(block_index, key, total_size, block_size, tile_size):
"""Uniform random sampling kernel function."""
keys = blocked_fold_in(block_index, key,
total_size=total_size,
block_size=block_size,
tile_size=tile_size)
return blocked_sampler.sample_block(jax.random.uniform,
keys,
block_size=block_size,
@ -74,17 +97,46 @@ class BlockedSamplerTest(jtu.JaxTestCase):
)
def test_block_shape_invariance(self, total_size, block_size_a,
block_size_b, tile_size, transpose_grid):
global_key = jax.random.key(0)
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
result_a = call_kernel(
uniform_kernel, grid_a, transpose_grid,
uniform_kernel, grid_a, transpose_grid, global_key,
total_size, block_size_a, tile_size)
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
result_b = call_kernel(
uniform_kernel, grid_b, transpose_grid,
uniform_kernel, grid_b, transpose_grid, global_key,
total_size, block_size_b, tile_size)
np.testing.assert_array_equal(result_a, result_b)
class BlockedFoldInTest(jtu.JaxTestCase):
@parameterized.named_parameters(
# Check that sampling a tensor of total size > jnp.iinfo(jnp.uint32).max works
# as expected. Specifically, blocked key folding does not depend on the total
# size of the tensor, but only the total number of tiles.
# Using a 3D grid (with very large inner dimensions) triggered an overflow in a
# previous implementation of blocked_fold_in.
dict(testcase_name='4096x512_vs_1024x2048',
total_size=(2, 64 * 1024, 64 * 1024), block_size_a=(1, 4096, 512),
block_size_b=(1, 1024, 2048), tile_size=(1, 1024, 512)),
)
def test_blocked_fold_in_shape_invariance(self, total_size, block_size_a,
block_size_b, tile_size):
global_key = jax.random.key(0)
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
result_a = call_kernel_3d(
blocked_fold_in, grid_a, global_key, total_size,
block_size_a, tile_size)
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
result_b = call_kernel_3d(
blocked_fold_in, grid_b, global_key, total_size,
block_size_b, tile_size)
np.testing.assert_array_equal(jax.random.key_data(result_a),
jax.random.key_data(result_b))
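A quick check of the size claim in the test parameters above, in plain Python arithmetic:

import numpy as np

total_elements = 2 * (64 * 1024) * (64 * 1024)    # 8_589_934_592 elements
assert total_elements > np.iinfo(np.uint32).max   # 4_294_967_295
tiles = total_elements // (1 * 1024 * 512)        # only 16_384 tiles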
if __name__ == "__main__":
absltest.main()

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for release_backend_clients."""
from absl.testing import absltest
import jax

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for --debug_nans."""
from absl.testing import absltest
import jax

View File

@ -20,12 +20,14 @@ from jax._src import config
from jax._src import error_check
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
JaxValueError = error_check.JaxValueError
config.parse_flags_with_absl()
jtu.request_cpu_devices(4)
@jtu.with_config(jax_check_tracer_leaks=True)
@ -190,6 +192,23 @@ class ErrorCheckTests(jtu.JaxTestCase):
):
jax.jit(error_check.raise_if_error)()
@parameterized.product(jit=[True, False])
@jtu.with_user_mesh((2, 2), ("x", "y"))
def test_error_check_explicit_mode(self, mesh, jit):
def f(x):
error_check.set_error_if(x <= 0, "x must be greater than 0")
return x + 1
if jit:
f = jax.jit(f)
sharding = NamedSharding(mesh, P("x", "y"))
x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding)
with error_check.error_checking_context():
f(x)
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
error_check.raise_if_error()
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for garbage allocation guard."""
import gc
import weakref

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax.numpy.ufunc and its methods."""
import itertools
from functools import partial

View File

@ -3618,6 +3618,15 @@ class LaxTest(jtu.JaxTestCase):
x = lax.optimization_barrier((2, 3))
self.assertEqual((2, 3), x)
def test_optimization_barrier_autodiff(self):
def f(x):
y = 1. * x
x, y = lax.optimization_barrier((x, y))
z = 2. * x
return y + z
g = jax.grad(f)(5.) # doesn't crash
self.assertAllClose(g, 3., check_dtypes=False)
class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the LAPAX linear algebra module."""
from functools import partial
import itertools
from typing import Iterator

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mesh utils."""
import collections
from collections.abc import Sequence

View File

@ -74,6 +74,37 @@ class LayoutInferenceTest(parameterized.TestCase):
self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
self.assertSequenceEqual(add.attributes["out_layouts"], [layout])
def test_infer_strided_layout_from_shape_cast(self):
shape = (16, 8)
elt_type = ir.BF16Type.get()
src_type = ir.VectorType.get(shape, elt_type)
dst_type = ir.VectorType.get([*reversed(shape)], elt_type)
op = None
def body(x):
nonlocal op
op = vector.ShapeCastOp(dst_type, x)
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(src_type)(body)
mgpu.infer_layout(self.module)
in_layout = layouts.to_layout_attr(
mgpu.WGStridedFragLayout.from_shaped_type(src_type)
)
out_layout = layouts.to_layout_attr(
mgpu.WGStridedFragLayout.from_shaped_type(dst_type)
)
self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout])
self.assertSequenceEqual(op.attributes["out_layouts"], [out_layout])
# Ensure that we can recover the original layout.
del op.attributes["in_layouts"]
mgpu.infer_layout(self.module)
self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout])
def test_infer_splat_layout_for_splat_constants(self):
shape = (16, 8)
elt_type = ir.BF16Type.get()

View File

@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Mosaic GPU DSL functions and utilities."""
from collections.abc import Sequence
import contextlib
import dataclasses
import enum
import itertools
@ -84,6 +84,20 @@ def mlir_sum(elems):
return total
@contextlib.contextmanager
def get_sass():
prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
try:
with jtu.capture_stdout() as output:
yield output
finally:
if prev_dump is not None:
os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
else:
del os.environ["MOSAIC_GPU_DUMP_SASS"]
def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
index = ir.IndexType.get()
thread_id = gpu.thread_id(gpu.Dimension.x)
@ -519,14 +533,38 @@ class WGMMALayoutTest(TestCase):
)()
np.testing.assert_array_equal(iota, expected)
@parameterized.named_parameters(
("bf16_i8", jnp.bfloat16, jnp.int8),
("i8_bf16", jnp.int8, jnp.bfloat16),
("i8_i8", jnp.int8, jnp.int8),
("i4_i4", jnp.int4, jnp.int4),
("i4_bf16", jnp.int4, jnp.bfloat16),
@parameterized.parameters(jnp.int8, jnp.int16, jnp.int32)
def test_sub_byte_conversion(self, jax_dtype_to):
jax_dtype_from = jnp.int4
def kernel(ctx, inp, out, smem):
del ctx # Unused.
smem_inp, smem_out = smem
copy(inp, smem_inp, swizzle=16)
t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16)
t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True)
t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
x = self.prng.integers(
low=-8, high=7, size=(1, 1, 64, 64), dtype=np.int32
).astype(jax_dtype_from)
y = x.astype(jax_dtype_to)
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y))
np.testing.assert_array_equal(f(x), y)
@parameterized.product(
jax_dtype_from_to=(
(jnp.int8, jnp.bfloat16),
(jnp.int4, jnp.bfloat16),
),
layout=(
fa.WGMMA_LAYOUT,
fa.WGMMA_LAYOUT_UPCAST_2X,
fa.WGMMA_LAYOUT_UPCAST_4X,
),
)
def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
def test_optimized_conversion(self, jax_dtype_from_to, layout):
jax_dtype_from, jax_dtype_to = jax_dtype_from_to
mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
m = 128
@ -539,7 +577,7 @@ class WGMMALayoutTest(TestCase):
smem_from,
swizzle=128,
is_signed=utils.is_signed(jax_dtype_from),
layout=fa._tiled_wgmma_layout((m, n))
layout=layout,
)
t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
t.store_tiled(smem_to, swizzle=128)
@ -2175,19 +2213,11 @@ class LayoutTest(TestCase):
.transpose(0, 2, 1, 3)
)
prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
try:
with jtu.capture_stdout() as get_sass:
iota = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), expected, expected,
[expected, expected, mgpu.TMABarrier()],
)(expected)
finally:
if prev_dump is not None:
os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
else:
del os.environ["MOSAIC_GPU_DUMP_SASS"]
with get_sass() as sass:
iota = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), expected, expected,
[expected, expected, mgpu.TMABarrier()],
)(expected)
np.testing.assert_array_equal(iota, expected)
# Verify that we don't use too many registers for the transfers.
@ -2200,7 +2230,7 @@ class LayoutTest(TestCase):
expected_regs //= 2
for instr in ("STS", "LDS"):
with self.subTest(instr + " count"):
addrs = re.findall(instr + r".* \[(.*)\]", get_sass())
addrs = re.findall(instr + r".* \[(.*)\]", sass())
def get_reg(addr):
if (pos := addr.find("+")) != -1:
return addr[:pos]
@ -2214,13 +2244,13 @@ class LayoutTest(TestCase):
col_tiling = swizzle // bytewidth(utils.dtype_to_ir_type(dtype))
m, n = 128, col_tiling * 2
tiling = (64, col_tiling)
tiled_layout = fa._tiled_wgmma_layout_for_upcast((m, n))
layout = fa.WGMMA_LAYOUT_UPCAST_2X
def kernel(ctx, in_, out, smems):
smem_in, smem_out, barrier = smems
ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
barrier.wait()
t = mgpu.FragmentedArray.load_tiled(
smem_in, swizzle=swizzle, is_signed=True, layout=tiled_layout
smem_in, swizzle=swizzle, is_signed=True, layout=layout
)
t.store_tiled(smem_out, swizzle=swizzle)
mgpu.commit_shared()
@ -2275,6 +2305,61 @@ class LayoutTest(TestCase):
)(x)
np.testing.assert_array_equal(y, y_ref)
@parameterized.parameters(
(fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int8, 1),
(fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int16, 1),
(fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X, jnp.int4, jnp.int4, 1),
(fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5),
(fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2),
)
def test_upcast_to_wgmma(
self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg
):
in_dtype = jnp.dtype(in_dtype)
out_dtype = jnp.dtype(jnp.int16)
out_dtype_mlir = utils.dtype_to_ir_type(out_dtype)
swizzle = 128
in_col_tiling = 8 * swizzle // jnp.iinfo(in_dtype).bits
in_tiling = (8, in_col_tiling)
out_col_tiling = swizzle // out_dtype.itemsize
out_tiling = (8, out_col_tiling)
m, n = 128, in_col_tiling * 2
regs_per_thread = None
def kernel(ctx, in_, out, smems):
nonlocal regs_per_thread
smem_in, smem_out, barrier = smems
ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
barrier.wait()
t = mgpu.FragmentedArray.load_tiled(
smem_in, swizzle=swizzle, is_signed=True, layout=start_layout
)
regs_per_thread = t.registers.size
t = t.astype(utils.dtype_to_ir_type(cast_dtype), is_signed=True)
t = t.to_layout(end_layout)
t = t.astype(out_dtype_mlir, is_signed=True)
t.store_tiled(smem_out, swizzle=swizzle)
mgpu.commit_shared()
ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
ctx.await_async_copy(0)
def tile(x, tiling):
return x.reshape(
x.shape[0] // tiling[0], tiling[0], x.shape[1] // tiling[1], tiling[1]
).transpose(0, 2, 1, 3)
in_iinfo = jnp.iinfo(in_dtype)
x = jax.random.randint(
jax.random.key(42), (m, n), in_iinfo.min, in_iinfo.max, dtype=jnp.int32
).astype(in_dtype)
xt = tile(x, in_tiling)
y = x.astype(out_dtype)
yt = tile(y, out_tiling)
f = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()],
)
with get_sass() as sass:
yt_kernel = f(xt)
np.testing.assert_array_equal(yt_kernel, yt)
self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)
@dataclasses.dataclass(frozen=True)
class Tile:

View File

@ -25,8 +25,11 @@ from jax._src.interpreters import mlir as mlir_interpreter
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import vector
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic.gpu import fragmented_array as fa
from jax.experimental.mosaic.gpu import inference_utils
from jax.experimental.mosaic.gpu import layouts as layouts_lib
import numpy as np
@ -162,6 +165,259 @@ class TransformInferenceTest(parameterized.TestCase):
)
self.assertEmpty(inference_utils.out_transforms(async_store_op))
def test_infer_transforms_for_vector_load_op_derives_from_destination(self):
vector_load_op = None
shape = (64, 64)
elt_ty = ir.BF16Type.get()
def body(smem_ref):
nonlocal vector_load_op
zero = arith.constant(ir.IntegerType.get_signless(32), 0)
vector_load_op = vector.LoadOp(
ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
)
with ir.InsertionPoint(self.module.body):
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
func.FuncOp.from_py_func(smem_ty)(body)
vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
[layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
)
mgpu.infer_transforms(self.module)
expected_transforms = ir.ArrayAttr.get([
mgpu.dialect.TileTransformAttr.get((8, 64)),
mgpu.dialect.SwizzleTransformAttr.get(128),
])
self.assertSequenceEqual(
inference_utils.in_transforms(vector_load_op), [expected_transforms]
)
self.assertEmpty(inference_utils.out_transforms(vector_load_op))
def test_infer_transforms_for_vector_load_op_derives_from_source(self):
vector_load_op = None
shape = (64, 64)
elt_ty = ir.BF16Type.get()
def body(smem_ref):
nonlocal vector_load_op
zero = arith.constant(ir.IntegerType.get_signless(32), 0)
vector_load_op = vector.LoadOp(
ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
)
with ir.InsertionPoint(self.module.body):
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
f = func.FuncOp.from_py_func(smem_ty)(body).func_op
vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
[layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
)
transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])
mgpu.infer_transforms(self.module)
self.assertSequenceEqual(
inference_utils.in_transforms(vector_load_op), [transforms]
)
self.assertEmpty(inference_utils.out_transforms(vector_load_op))
def test_infer_transforms_for_vector_load_op_raises_on_mismatches(self):
vector_load_op = None
shape = (64, 64)
elt_ty = ir.BF16Type.get()
def body(smem_ref):
nonlocal vector_load_op
zero = arith.constant(ir.IntegerType.get_signless(32), 0)
vector_load_op = vector.LoadOp(
ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
)
with ir.InsertionPoint(self.module.body):
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
f = func.FuncOp.from_py_func(smem_ty)(body).func_op
vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
[layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
)
transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])
with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
mgpu.infer_transforms(self.module)
def test_infer_transforms_for_vector_store_op_derives_from_destination(self):
vector_store_op = None
shape = (64, 64)
elt_ty = ir.BF16Type.get()
def body(smem_ref, value_to_store):
nonlocal vector_store_op
zero = arith.constant(ir.IntegerType.get_signless(32), 0)
vector_store_op = vector.StoreOp(
value_to_store, smem_ref, [zero] * len(shape)
)
with ir.InsertionPoint(self.module.body):
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
value_ty = ir.VectorType.get(shape, elt_ty)
func.FuncOp.from_py_func(smem_ty, value_ty)(body)
vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
[layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
)
mgpu.infer_transforms(self.module)
expected_transforms = ir.ArrayAttr.get([
mgpu.dialect.TileTransformAttr.get((8, 64)),
mgpu.dialect.SwizzleTransformAttr.get(128),
])
self.assertSequenceEqual(
inference_utils.in_transforms(vector_store_op), [expected_transforms]
)
self.assertEmpty(inference_utils.out_transforms(vector_store_op))
def test_infer_transforms_for_vector_store_op_derives_from_source(self):
vector_store_op = None
shape = (64, 64)
elt_ty = ir.BF16Type.get()
def body(smem_ref, value_to_store):
nonlocal vector_store_op
zero = arith.constant(ir.IntegerType.get_signless(32), 0)
vector_store_op = vector.StoreOp(
value_to_store, smem_ref, [zero] * len(shape)
)
with ir.InsertionPoint(self.module.body):
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
value_ty = ir.VectorType.get(shape, elt_ty)
f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op
vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
[layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
)
transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])
mgpu.infer_transforms(self.module)
self.assertSequenceEqual(
inference_utils.in_transforms(vector_store_op), [transforms]
)
self.assertEmpty(inference_utils.out_transforms(vector_store_op))
def test_infer_transforms_for_vector_store_op_raises_on_mismatches(self):
vector_store_op = None
shape = (64, 64)
elt_ty = ir.BF16Type.get()
def body(smem_ref, value_to_store):
nonlocal vector_store_op
zero = arith.constant(ir.IntegerType.get_signless(32), 0)
vector_store_op = vector.StoreOp(
value_to_store, smem_ref, [zero] * len(shape)
)
with ir.InsertionPoint(self.module.body):
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
value_ty = ir.VectorType.get(shape, elt_ty)
f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op
vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
[layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
)
transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])
with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
mgpu.infer_transforms(self.module)
def test_infer_transforms_for_slice_smem_op_derives_from_user(self):
slice_smem_op = vector_load_op = None
shape = (64, 64)
elt_ty = ir.BF16Type.get()
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
def body(offset):
nonlocal slice_smem_op, vector_load_op
slice_smem_op = mgpu.dialect.SliceSMEMOp(
ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset
)
zero = arith.constant(ir.IntegerType.get_signless(32), 0)
load_offsets = [zero] * len(shape)
vector_load_op = vector.LoadOp(
ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets
)
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body)
vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
[layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
)
mgpu.infer_transforms(self.module)
expected_transforms = ir.ArrayAttr.get([
mgpu.dialect.TileTransformAttr.get((8, 64)),
mgpu.dialect.SwizzleTransformAttr.get(128),
])
self.assertEmpty(inference_utils.in_transforms(slice_smem_op))
self.assertSequenceEqual(
inference_utils.out_transforms(slice_smem_op), [expected_transforms]
)
def test_infer_transforms_for_slice_smem_op_raises_on_mismatches(self):
slice_smem_op = vector_load_op1 = vector_load_op2 = None
shape = (64, 64)
elt_ty = ir.BF16Type.get()
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
def body(offset):
nonlocal slice_smem_op, vector_load_op1, vector_load_op2
slice_smem_op = mgpu.dialect.SliceSMEMOp(
ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset
)
zero = arith.constant(ir.IntegerType.get_signless(32), 0)
load_offsets = [zero] * len(shape)
vector_load_op1 = vector.LoadOp(
ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets
)
vector_load_op2 = vector.LoadOp(
ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets
)
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body)
vector_load_op1.attributes["out_layouts"] = ir.ArrayAttr.get(
[layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
)
vector_load_op2.attributes["out_layouts"] = ir.ArrayAttr.get(
[layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
)
vector_load_op2.attributes["in_transforms"] = ir.ArrayAttr.get(
[ir.ArrayAttr.get([mgpu.dialect.TransposeTransformAttr.get((1, 0))])]
)
with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
mgpu.infer_transforms(self.module)
if __name__ == "__main__":
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Mosaic GPU CUPTI-based profiler."""
from absl.testing import absltest, parameterized
import jax

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nn module."""
import collections
from functools import partial
import itertools

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the optimizers module."""
import functools
from absl.testing import absltest

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pull block spec."""
from absl.testing import absltest
from absl.testing import parameterized
import jax

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Pallas indexing logic and abstractions."""
from __future__ import annotations
import sys
import unittest

View File

@ -185,7 +185,7 @@ class PallasCallTest(PallasTest):
np.testing.assert_array_equal(kernel(x, y), x + y[0])
@parameterized.product(
shape=[(128,)], thread_semantics=[*plgpu.ThreadSemantics]
shape=[(128,), (128, 128)], thread_semantics=[*plgpu.ThreadSemantics]
)
def test_reduce_sum(self, shape, thread_semantics):
@functools.partial(

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for common JAX operations within pallas_call."""
from collections.abc import Sequence
import functools
import itertools

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Pallas error handling."""
import functools
import traceback

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TPU specific operations within pallas_call."""
import functools
import math

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for distributed pallas TPU operations."""
import functools
import os
import tempfile

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for random ops in Pallas + Mosaic."""
from absl.testing import absltest
from absl.testing import parameterized

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Pallas mesh API."""
import functools
from absl.testing import absltest
import jax

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for splash_attention."""
from __future__ import annotations
from collections.abc import Callable

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for splash_attention_masks."""
from __future__ import annotations
from absl.testing import absltest

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for interoperability between JAX and pickling libraries."""
import pickle
import unittest

View File

@ -6138,6 +6138,19 @@ class ShardingInTypesTest(jtu.JaxTestCase):
    self.assertDictEqual(out.sharding.mesh._axis_types_dict,
                         {AxisType.Auto: ('x',)})

  @jtu.with_user_mesh((2,), 'x')
  def test_device_put_use_mesh(self, mesh):
    out = jax.device_put(np.arange(8), P('x'))
    self.assertArraysEqual(out, np.arange(8))
    self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))

  def test_device_put_no_use_mesh_error(self):
    with self.assertRaisesRegex(
        ValueError,
        'Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is'
        ' passed to device_put'):
      jax.device_put(np.arange(8), P('x'))

  @jtu.with_user_mesh((2,), 'x')
  def test_inputs_different_context(self, mesh):
    np_inp = np.arange(16).reshape(8, 2)
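
A minimal usage sketch of the behavior the new device_put tests exercise (illustrative only; assumes at least two visible devices, and the variable names are ours):

import jax
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))  # assumes >= 2 devices are available
with jax.sharding.use_mesh(mesh):
  # With an ambient mesh set, device_put accepts a bare PartitionSpec.
  out = jax.device_put(np.arange(8), P('x'))
assert out.sharding == NamedSharding(mesh, P('x'))
# Without use_mesh, device_put(np.arange(8), P('x')) raises the ValueError
# asserted in test_device_put_no_use_mesh_error above.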

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License
"""Tests for the library of QDWH-based polar decomposition."""
import functools
from absl.testing import absltest

View File

@ -365,6 +365,38 @@ class LaxRandomTest(jtu.JaxTestCase):
    pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
    self._CheckChiSquared(samples, pmf=pmf)

  @jtu.sample_product(
      logits_shape=[(7,), (8, 9), (10, 11, 12)],
      prefix_shape=[(2,), (3, 4), (5, 6)],
  )
  def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape):
    key = random.key(0)
    key, subkey = random.split(key)
    logits = random.normal(subkey, logits_shape)
    key, subkey = random.split(key)
    axis = random.randint(subkey, (), -len(logits_shape), len(logits_shape))
    dists_shape = tuple(np.delete(logits_shape, axis))
    n_categories = logits_shape[axis]
    shape = prefix_shape + dists_shape
    prefix_size = math.prod(prefix_shape)
    if n_categories < prefix_size:
      with self.assertRaisesRegex(ValueError, "Number of samples without replacement"):
        random.categorical(key, logits, axis=axis, shape=shape, replace=False)
    else:
      output = random.categorical(key, logits, axis=axis, shape=shape, replace=False)
      self.assertEqual(output.shape, shape)
      assert (0 <= output).all()
      assert (output < n_categories).all()
      flat = output.reshape((prefix_size, math.prod(dists_shape)))
      counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat)
      assert (counts <= 1).all()

  def testBernoulliShape(self):
    key = self.make_key(0)
    with jax.numpy_rank_promotion('allow'):
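
A minimal usage sketch of the new replace=False option exercised by testCategoricalWithoutReplacement above (illustrative only, not part of the diff):

import jax
import jax.numpy as jnp

key = jax.random.key(0)
logits = jnp.zeros(7)  # seven equally likely categories
# Draw three category indices without replacement; they are guaranteed distinct.
samples = jax.random.categorical(key, logits, shape=(3,), replace=False)
assert samples.shape == (3,)
assert len(set(samples.tolist())) == 3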

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the shape-polymorphic export."""
from __future__ import annotations

View File

@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for stack."""
from absl.testing import absltest
import jax

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Stax library."""
from absl.testing import absltest
import numpy as np

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License
"""Tests for the library of QDWH-based singular value decomposition."""
import functools
import jax

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transfer guards."""
import contextlib
import pickle

View File

@ -201,5 +201,49 @@ class SafeZipTest(jtu.JaxTestCase):
      util.safe_zip((), range(3))


class Node:
  def __init__(self, parents):
    self.parents = parents


class TopologicalSortTest(jtu.JaxTestCase):

  def _check_topological_sort(self, nodes, order):
    self.assertEqual(sorted(nodes, key=id), sorted(order, key=id))
    visited = set()
    for node in order:
      self.assertTrue(all(id(parent) in visited for parent in node.parents))
      visited.add(id(node))

  def test_basic(self):
    a = Node([])
    b = Node([a])
    c = Node([a])
    d = Node([a, c])
    e = Node([b, c])
    out = util.toposort([a, d, e])
    self._check_topological_sort([a, b, c, d, e], out)

  def test_stick(self):
    a = Node([])
    b = Node([a])
    c = Node([b])
    d = Node([c])
    e = Node([d])
    out = util.toposort([e])
    self._check_topological_sort([a, b, c, d, e], out)

  def test_diamonds(self):
    a = Node([])
    b = Node([a])
    c = Node([a])
    d = Node([b, c])
    e = Node([d])
    f = Node([d])
    g = Node([e, f])
    out = util.toposort([g])
    self._check_topological_sort([a, b, c, d, e, f, g], out)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
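
For context, a reference sketch of the property _check_topological_sort verifies, namely that every node appears after all of its parents; this is illustrative and not necessarily how util.toposort is implemented:

def toposort_reference(end_nodes):
  """Depth-first post-order walk: parents are emitted before their children."""
  order, seen = [], set()

  def visit(node):
    if id(node) in seen:
      return
    seen.add(id(node))
    for parent in node.parents:
      visit(parent)
    order.append(node)  # appended only after all of its parents

  for node in end_nodes:
    visit(node)
  return order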

View File

@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
XLA_COMMIT = "4c4aa96f9ffec4bb963b50c50192aeab4da9dc4a"
XLA_SHA256 = "c373e52b2f8b4175c69e99e636ad64b3bcf33fb44d1b7ad6ef8f4162c9052af8"
XLA_COMMIT = "3bb765472122548cc227b8bd2990f00bd533f438"
XLA_SHA256 = "72126aac7602153aee985ca20f73d11c39e3ba9cfb8027492951e787559d0497"
def repo():
tf_http_archive(
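
Following the update recipe in the comment above, a hedged sketch of recomputing XLA_SHA256 for the new XLA_COMMIT shown in this hunk (equivalent to the curl | sha256sum command; the archive URL form is taken from that comment):

import hashlib
import urllib.request

commit = "3bb765472122548cc227b8bd2990f00bd533f438"
url = f"https://github.com/openxla/xla/archive/{commit}.tar.gz"
# Download the source archive and hash it; the hex digest is the new XLA_SHA256.
digest = hashlib.sha256(urllib.request.urlopen(url).read()).hexdigest()
print(digest)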