Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00

Merge pull request #294 from ROCm/ci-upstream-sync-151_1
CI: 03/18/25 upstream sync

This commit is contained in: commit c46b4fc02b

.github/workflows/pytest_cpu.yml (vendored, 5 changes)
@@ -118,6 +118,11 @@ jobs:
       run: |
         $JAXCI_PYTHON -m pip install uv~=0.5.30
         $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
+
+        # CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
+        if [[ $OS == "linux" && $ARCH == "aarch64" ]]; then
+          $JAXCI_PYTHON -m uv pip install numpy~=2.1.0
+        fi
       # Halt for testing
       - name: Wait For Connection
         uses: google-ml-infra/actions/ci_connection@main
.github/workflows/pytest_cuda.yml (vendored, 3 changes)

@@ -54,7 +54,8 @@ jobs:
     runs-on: ${{ inputs.runner }}
     # TODO: Update to the generic ML ecosystem test containers when they are ready.
     container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
-                   (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
+                   (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') ||
+                   (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }}
     name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"

     env:
.github/workflows/wheel_tests_continuous.yml (vendored, 24 changes)

@@ -110,18 +110,30 @@ jobs:
       fail-fast: false # don't cancel all jobs on failure
       matrix:
         # Python values need to match the matrix stategy in the artifact build jobs above
-        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
+        # See exlusions for what is fully tested
+        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"]
         python: ["3.10",]
-        cuda: ["12.3", "12.1"]
+        cuda: ["12.1","12.3","12.8"]
         enable-x64: [1, 0]
         exclude:
-          # Run only a single configuration on H100 to save resources
+          # L4 does not run on cuda 12.8 but tests other configs
+          - runner: "linux-x86-g2-48-l4-4gpu"
+            cuda: "12.8"
+          # H100 runs only a single config, CUDA 12.3 Enable x64 1
+          - runner: "linux-x86-a3-8g-h100-8gpu"
+            cuda: "12.8"
           - runner: "linux-x86-a3-8g-h100-8gpu"
-            python: "3.10"
             cuda: "12.1"
           - runner: "linux-x86-a3-8g-h100-8gpu"
-            python: "3.10"
-            enable-x64: 0
+            enable-x64: "0"
+          # B200 runs only a single config, CUDA 12.8 Enable x64 1
+          - runner: "linux-x86-a4-224-b200-1gpu"
+            enable-x64: "0"
+          - runner: "linux-x86-a4-224-b200-1gpu"
+            cuda: "12.1"
+          - runner: "linux-x86-a4-224-b200-1gpu"
+            cuda: "12.3"
+
     name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
     with:
       runner: ${{ matrix.runner }}
@@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   {func}`jax.lax.dynamic_update_slice` and related functions. The default is
   true, matching the current behavior. If set to false, JAX does not need to
   emit code clamping negative indices, which improves code size.
+* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
+  without replacement.

 ## jax 0.5.2 (Mar 4, 2025)

@@ -19,6 +19,3 @@ matplotlib~=3.8.4; python_version=="3.10"
 matplotlib; python_version>="3.11"
 opt-einsum
 auditwheel
-
-# CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
-numpy~=2.1.0; platform_system == "Linux" and platform_machine == "aarch64"
@@ -49,13 +49,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": 7,
     "metadata": {
-     "colab": {
-      "base_uri": "https://localhost:8080/"
-     },
-     "id": "hVi6mApuVw3r",
-     "outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf"
+     "id": "hVi6mApuVw3r"
     },
     "outputs": [],
     "source": [
@@ -84,13 +80,13 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
+    "execution_count": 8,
     "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "mzDIDvj7Vw0k",
-    "outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434"
+    "outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a"
    },
    "outputs": [
     {
@@ -119,13 +115,13 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": 9,
     "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "IyPx_-IBVwxr",
-    "outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499"
+    "outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb"
    },
    "outputs": [
     {
@@ -141,7 +137,7 @@
      "Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)"
     ]
    },
-   "execution_count": 3,
+   "execution_count": 9,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -172,13 +168,13 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
+    "execution_count": 10,
     "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "NO2ulM_QW7a8",
-    "outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb"
+    "outputId": "d888371b-080e-4bff-be5d-ea56beda3aac"
    },
    "outputs": [
     {
@@ -208,13 +204,13 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 5,
+    "execution_count": 11,
     "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "1-TzmA0AXCAf",
-    "outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71"
+    "outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2"
    },
    "outputs": [
     {
@@ -256,13 +252,13 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": 12,
     "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "Gy7ABds3XND3",
-    "outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b"
+    "outputId": "0d72dad2-381a-4e96-f771-40d705da1376"
    },
    "outputs": [
     {
@@ -297,13 +293,13 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 7,
+    "execution_count": 13,
     "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "grCcotr-XQjY",
-    "outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a"
+    "outputId": "c2db656c-809f-49a6-c948-629d6420360c"
    },
    "outputs": [
     {
@@ -324,7 +320,7 @@
      " [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)"
     ]
    },
-   "execution_count": 7,
+   "execution_count": 13,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -460,13 +456,13 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 13,
+    "execution_count": 14,
     "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "fpFEaMBcXsJG",
-    "outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660"
+    "outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef"
    },
    "outputs": [
     {
@@ -479,13 +475,6 @@
      "We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n",
      "Result type: ShapedArray(int32[4@X,4])\n"
     ]
-   },
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "Result type: ShapedArray(int32[4@X,4])\n"
-    ]
    }
   ],
   "source": [
@@ -550,13 +539,13 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 10,
+    "execution_count": 15,
     "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "geptWrdYX0OM",
-    "outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f"
+    "outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f"
    },
    "outputs": [
     {
@@ -588,7 +577,88 @@
   {
    "cell_type": "markdown",
    "metadata": {
-    "id": "AQQjzUeGX4P6"
+    "id": "LZWjgiMZ7uSS"
+   },
+   "source": [
+    "You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "IVzPSkp77uCF",
+    "outputId": "db80a604-98ac-4343-8677-23729adf7ffc"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n",
+      "x.sharding: ShapedArray(float32[4@X,4@Y])\n",
+      "\n",
+      "mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n",
+      "y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n",
+      "\n",
+      "z.sharding: ShapedArray(float32[4@X,4@Y])\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],\n",
+       " [-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],\n",
+       " [ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],\n",
+       " [-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import functools\n",
+    "\n",
+    "@functools.partial(auto_axes, axes='X')\n",
+    "def g(y):\n",
+    "  print(f'mesh inside g: {get_abstract_mesh()}')\n",
+    "  print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n",
+    "  return y * 2\n",
+    "\n",
+    "@jax.jit\n",
+    "def f(arr1):\n",
+    "  print(f'mesh inside f: {get_abstract_mesh()}')\n",
+    "  x = jnp.sin(arr1)\n",
+    "  print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n",
+    "\n",
+    "  z = g(x, out_shardings=P(\"X\", \"Y\"))\n",
+    "\n",
+    "  print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n",
+    "  return z + 1\n",
+    "\n",
+    "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n",
+    "f(some_x)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "_3sfJjRq8w9f"
+   },
+   "source": [
+    "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "sJcWbfAh7UcO"
    },
    "source": [
     "## Concrete array shardings can mention `Auto` mesh axis\n",
@@ -606,7 +676,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 25,
+    "execution_count": null,
     "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
@@ -708,5 +778,5 @@
  }
  },
  "nbformat": 4,
- "nbformat_minor": 4
+ "nbformat_minor": 0
 }
@@ -50,12 +50,8 @@ expect there to be bugs and unimplemented cases. Please let us know when you
 find something that doesn't work!

 ```{code-cell} ipython3
----
-colab:
-  base_uri: https://localhost:8080/
-id: hVi6mApuVw3r
-outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf
----
+:id: hVi6mApuVw3r
 import jax
 import numpy as np
 import jax.numpy as jnp
@@ -79,7 +75,7 @@ scalar) using `jax.typeof`:
 colab:
   base_uri: https://localhost:8080/
 id: mzDIDvj7Vw0k
-outputId: 417b8453-9c86-4e76-a886-4fa9fdb16434
+outputId: 09ef049b-461f-47db-bf58-dc10b42fe40a
 ---
 some_array = np.arange(8)
 print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
@@ -96,7 +92,7 @@ under a jit).
 colab:
   base_uri: https://localhost:8080/
 id: IyPx_-IBVwxr
-outputId: 7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499
+outputId: 0cd3122f-e579-45d7-868d-e42bb0eacddb
 ---
 @jax.jit
 def foo(x):
@@ -121,7 +117,7 @@ mesh afterwards then you can use the context manager `jax.sharding.use_mesh` ins
 colab:
   base_uri: https://localhost:8080/
 id: NO2ulM_QW7a8
-outputId: ea313610-146c-41f4-95b4-c5a5b2b407cb
+outputId: d888371b-080e-4bff-be5d-ea56beda3aac
 ---
 mesh = jax.make_mesh((2, 4), ("X", "Y"),
                      axis_types=(AxisType.Explicit, AxisType.Explicit))
@@ -139,7 +135,7 @@ Now we can create some sharded arrays using `reshard`:
 colab:
   base_uri: https://localhost:8080/
 id: 1-TzmA0AXCAf
-outputId: 15b33b6d-3915-4725-da6d-4f31fb78fe71
+outputId: 1c7cc3ac-4b0e-42b7-facc-c706af10d7d2
 ---
 replicated_array = np.arange(8).reshape(4, 2)
 sharded_array = reshard(replicated_array, P("X", None))
@@ -163,7 +159,7 @@ These shardings associated with JAX-level types propagate through operations. Fo
 colab:
   base_uri: https://localhost:8080/
 id: Gy7ABds3XND3
-outputId: 4ced73ed-5872-45f3-a4a6-2138f942e01b
+outputId: 0d72dad2-381a-4e96-f771-40d705da1376
 ---
 arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
 arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))
@@ -184,7 +180,7 @@ We can do the same type querying under a jit:
 colab:
   base_uri: https://localhost:8080/
 id: grCcotr-XQjY
-outputId: 9a9f381d-5111-4824-9bc0-cb2472cb8e6a
+outputId: c2db656c-809f-49a6-c948-629d6420360c
 ---
 @jax.jit
 def add_arrays(x, y):
@@ -294,7 +290,7 @@ the first axis only, like `f32[4@X, 4]`. You can do this as follows:
 colab:
   base_uri: https://localhost:8080/
 id: fpFEaMBcXsJG
-outputId: d28a69eb-260f-4fc5-8f19-2cc64cc70660
+outputId: 5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef
 ---
 some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
 some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))
@@ -355,7 +351,7 @@ The current mesh tells us which sharding mode we're in. We can query it with
 colab:
   base_uri: https://localhost:8080/
 id: geptWrdYX0OM
-outputId: c0e62eb1-9f79-4d1c-e708-526165ca680f
+outputId: b8c3813f-60bb-4ccf-9da7-73462c57963f
 ---
 print(f"Current mesh is: {get_abstract_mesh()}")
 ```
@@ -369,7 +365,45 @@ sharding mode for each mesh axis. Shardings (on JAX-level types) can only
 mention _explicit_ mesh axes and collective operations like `psum` can only
 mention _manual_ mesh axes.

-+++ {"id": "AQQjzUeGX4P6"}
++++ {"id": "LZWjgiMZ7uSS"}
+
+You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example:
+
+```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+id: IVzPSkp77uCF
+outputId: db80a604-98ac-4343-8677-23729adf7ffc
+---
+import functools
+
+@functools.partial(auto_axes, axes='X')
+def g(y):
+  print(f'mesh inside g: {get_abstract_mesh()}')
+  print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n')
+  return y * 2
+
+@jax.jit
+def f(arr1):
+  print(f'mesh inside f: {get_abstract_mesh()}')
+  x = jnp.sin(arr1)
+  print(f'x.sharding: {jax.typeof(x)}', end='\n\n')
+
+  z = g(x, out_shardings=P("X", "Y"))
+
+  print(f'z.sharding: {jax.typeof(z)}', end="\n\n")
+  return z + 1
+
+some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y"))
+f(some_x)
+```
+
++++ {"id": "_3sfJjRq8w9f"}
+
+As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`.
+
++++ {"id": "sJcWbfAh7UcO"}
+
 ## Concrete array shardings can mention `Auto` mesh axis

@@ -299,7 +299,7 @@
     " ):\n",
     " \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n",
     " del idxs_k_ref\n",
-    " blk_idx = pl.program_id(0)\n",
+    " blk_idx = pl.program_id(1)\n",
     " is_start = blk_idx == 0\n",
     " changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n",
     " @pl.when(is_start | changed_blocks)\n",
@@ -314,13 +314,13 @@
     " o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n",
     "\n",
     "\n",
-    "def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
+    "def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
     " del j, blk_idxs_i, blk_idxs_k\n",
     " return (blk_idx, 0, 0)\n",
-    "def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
+    "def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
     " del blk_idxs_i\n",
     " return (blk_idxs_k[blk_idx], j)\n",
-    "def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
+    "def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
     " del blk_idxs_k\n",
     " return (blk_idxs_i[blk_idx], j)\n",
     "\n",
@@ -335,7 +335,7 @@
     " num_scalar_prefetch=2,\n",
     " # Note that while num_blocks is static here, Pallas does support\n",
     " # dynamic grid sizes.\n",
-    " grid=(num_blocks, N // blk_N),\n",
+    " grid=(N // blk_N, num_blocks),\n",
     " in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n",
     " pl.BlockSpec((blk_K, blk_N), y_map),\n",
     " # Placeholder for a zeros-array used by input_output_aliases.\n",
@@ -239,7 +239,7 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref,  # Scalar prefetch inputs.
                ):
   """A DSD (Dense = Sparse @ Dense) matmul kernel."""
   del idxs_k_ref
-  blk_idx = pl.program_id(0)
+  blk_idx = pl.program_id(1)
   is_start = blk_idx == 0
   changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
   @pl.when(is_start | changed_blocks)
@@ -254,13 +254,13 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref,  # Scalar prefetch inputs.
   o_ref[...] = accum_scratch[...].astype(o_ref.dtype)


-def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
+def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
   del j, blk_idxs_i, blk_idxs_k
   return (blk_idx, 0, 0)
-def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
+def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
   del blk_idxs_i
   return (blk_idxs_k[blk_idx], j)
-def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
+def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
   del blk_idxs_k
   return (blk_idxs_i[blk_idx], j)

@@ -275,7 +275,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
     num_scalar_prefetch=2,
     # Note that while num_blocks is static here, Pallas does support
     # dynamic grid sizes.
-    grid=(num_blocks, N // blk_N),
+    grid=(N // blk_N, num_blocks),
     in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
               pl.BlockSpec((blk_K, blk_N), y_map),
               # Placeholder for a zeros-array used by input_output_aliases.
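Note: the grid reorder in both notebook variants makes the sparse-block axis the trailing grid dimension, so `blk_idx` becomes `pl.program_id(1)`. An illustrative pure-Python loop nest (not Pallas code; it assumes the TPU convention that the last grid axis varies fastest, and uses hypothetical sizes):

N, blk_N, num_blocks = 8, 4, 3  # hypothetical sizes
for j in range(N // blk_N):          # grid axis 0 -> pl.program_id(0)
  for blk_idx in range(num_blocks):  # grid axis 1 -> pl.program_id(1), fastest-varying
    # For a fixed output column strip j, all sparse blocks are visited
    # consecutively; the kernel's accumulator reset/flush logic relies on this.
    pass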
@@ -81,7 +81,7 @@ int main(int argc, char** argv) {
   xla::XlaComputation xla_computation(test_module_proto);
   xla::CompileOptions compile_options;
   std::unique_ptr<xla::PjRtLoadedExecutable> executable =
-      client->Compile(xla_computation, compile_options).value();
+      client->CompileAndLoad(xla_computation, compile_options).value();

   // Prepare inputs.
   xla::Literal literal_x =
@@ -799,7 +799,7 @@ pytype_strict_library(
 )

 # This target only supports sm_90 GPUs.
-py_library(
+py_library_providing_imports_info(
     name = "mosaic_gpu",
     srcs = glob(["experimental/mosaic/gpu/*.py"]),
     visibility = [
@@ -824,6 +824,7 @@ py_library(
         "//jaxlib/mlir:pass_manager",
         "//jaxlib/mlir:scf_dialect",
         "//jaxlib/mlir:vector_dialect",
+        "//jaxlib/mosaic/python:gpu_dialect",
     ] + py_deps("absl/flags") + py_deps("numpy"),
 )

@@ -67,7 +67,9 @@ from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
+from jax._src.mesh import get_concrete_mesh
+from jax._src.sharding_impls import (
+    PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding)
 from jax._src.layout import Layout, AutoLayout
 from jax._src.traceback_util import api_boundary
 from jax._src import tree_util
@@ -2280,11 +2282,20 @@ def _check_sharding(aval, s):
       (s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False)
   s.shard_shape(aval.shape)  # should raise an Error if incompatible

+
+def pspec_to_sharding(val):
+  if isinstance(val, P):
+    mesh = get_concrete_mesh()
+    if mesh is None:
+      raise ValueError(
+          "Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is"
+          " passed to device_put")
+    return NamedSharding(mesh, val)
+  return val
+
 def device_put(
     x,
-    device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
-    *, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
+    device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
+    *, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
     donate: bool | Any = False, may_alias: bool | None | Any = None):
   """Transfers ``x`` to ``device``.

@@ -2333,6 +2344,9 @@ def device_put(
   src_flat = flatten_axes("device_put source", treedef, src)
   src_flat = list(map(_infer_src_sharding, src_flat, x_flat))

+  device_flat = map(pspec_to_sharding, device_flat)
+  src_flat = map(pspec_to_sharding, src_flat)
+
   if isinstance(donate, bool):
     donate_flat = [donate] * len(x_flat)
   else:
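The hunks above teach `jax.device_put` to accept a bare `PartitionSpec`, resolved against the currently active concrete mesh. A minimal usage sketch (hypothetical axis name and shape; it requires a mesh set via `jax.sharding.use_mesh`, per the error message added above):

import jax
import numpy as np
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ("x",))
with jax.sharding.use_mesh(mesh):
  # The spec is converted by pspec_to_sharding into NamedSharding(mesh, P("x")).
  y = jax.device_put(np.arange(8), P("x"))
# Without an active mesh, the same call raises the ValueError added above.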
@@ -28,17 +28,17 @@ class SampleFn(Protocol):
     ...


-def _compute_scalar_index(iteration_index: Sequence[int],
-                          total_size: Shape,
-                          block_size: Shape,
-                          block_index: Sequence[int]) -> int:
-  ndims = len(iteration_index)
+def _compute_tile_index(block_index: Sequence[int],
+                        total_size_in_blocks: Shape,
+                        block_size_in_tiles: Shape,
+                        tile_index_in_block: Sequence[int]) -> int:
+  ndims = len(block_index)
   dim_size = 1
   total_idx = 0
   for i in range(ndims-1, -1, -1):
-    dim_idx = block_index[i] + iteration_index[i] * block_size[i]
+    dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
     total_idx += dim_idx * dim_size
-    dim_size *= total_size[i]
+    dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
   return total_idx


@@ -99,18 +99,23 @@ def blocked_fold_in(
     An N-dimensional nested list of keys required to sample the tiles
     corresponding to the block specified by `block_index`.
   """
-  size_in_blocks = tuple(
-    _shape // _element for _shape, _element in zip(block_size, tile_size))
+  block_size_in_tiles = tuple(
+    _shape // _element for _shape, _element in zip(block_size, tile_size)
+  )
+
+  total_size_in_blocks = tuple(
+    _shape // _element for _shape, _element in zip(total_size, block_size)
+  )
+
   def _keygen_loop(axis, prefix):
-    if axis == len(size_in_blocks):
+    if axis == len(block_size_in_tiles):
       subtile_key = jax.random.fold_in(
-        global_key, _compute_scalar_index(
-          block_index, total_size, size_in_blocks, prefix))
+        global_key, _compute_tile_index(
+          block_index, total_size_in_blocks, block_size_in_tiles, prefix))
       return subtile_key
     else:
       keys = []
-      for i in range(size_in_blocks[axis]):
+      for i in range(block_size_in_tiles[axis]):
         keys.append(_keygen_loop(axis+1, prefix+(i,)))
       return keys
   return _keygen_loop(0, tuple())
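The renamed helper also changes the stride arithmetic: each axis now advances by the total tile count along that axis, `total_size_in_blocks[i] * block_size_in_tiles[i]`. A small pure-Python check of the row-major formula (illustrative only, with hypothetical sizes):

def tile_index(block_index, total_size_in_blocks, block_size_in_tiles, tile_in_block):
  idx, stride = 0, 1
  for i in range(len(block_index) - 1, -1, -1):
    # Global tile coordinate along axis i, then row-major flattening.
    dim_idx = tile_in_block[i] + block_index[i] * block_size_in_tiles[i]
    idx += dim_idx * stride
    stride *= total_size_in_blocks[i] * block_size_in_tiles[i]
  return idx

# A 2x2 grid of blocks, each 3x2 tiles: tile (2, 1) of block (1, 0)
# sits at row 5, column 1 of the global 6x4 tile grid.
assert tile_index((1, 0), (2, 2), (3, 2), (2, 1)) == 5 * 4 + 1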
@@ -446,7 +446,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
   if len(devices) == 1:
     # If we only have one device in our computation, we can construct a
     # replicated HloSharding and call it right now.
-    _hlo_sharding_callback(sharding_impls.get_replicated_hlo_sharding())
+    _hlo_sharding_callback(sharding_impls.replicated_hlo_sharding)
     return []

   key = xc.encode_inspect_sharding_callback(_hlo_sharding_callback)
@@ -466,6 +466,9 @@ def _device_put_sharding_impl(x, aval, device, copy):
   if not s.is_fully_addressable:
     if ((isinstance(x, array.ArrayImpl) and not x._committed) or
         type(x) in array_types):
+      # TODO(emilyaf): Remove this condition when jit works when a sharding
+      # has no local devices.
+      if not config.enable_empty_arrays.value:
       multihost_utils.assert_equal(
           x, fail_message=(
               f"{type(x)} passed to device_put is not the same on each"
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax._src import source_info_util
|
from jax._src import source_info_util
|
||||||
from jax._src import traceback_util
|
from jax._src import traceback_util
|
||||||
|
import jax._src.mesh as mesh_lib
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from jax.sharding import NamedSharding, PartitionSpec as P
|
||||||
|
|
||||||
|
|
||||||
Traceback = source_info_util.Traceback
|
Traceback = source_info_util.Traceback
|
||||||
@ -54,17 +58,61 @@ _error_storage = _ErrorStorage()
|
|||||||
|
|
||||||
|
|
||||||
def _initialize_error_code_ref() -> None:
|
def _initialize_error_code_ref() -> None:
|
||||||
"""Initialize error_code_ref in the current thread."""
|
"""Initialize error_code_ref in the current thread.
|
||||||
|
|
||||||
|
The size of the error code array is determined by the mesh in the context. In
|
||||||
|
single-device environment, the array is a scalar. In multi-device
|
||||||
|
environment, the array has the same shape as the mesh.
|
||||||
|
"""
|
||||||
with core.eval_context():
|
with core.eval_context():
|
||||||
|
# Get mesh from the context.
|
||||||
|
mesh = mesh_lib.get_concrete_mesh()
|
||||||
|
|
||||||
|
if mesh is None: # single-device case.
|
||||||
error_code = jnp.uint32(_NO_ERROR)
|
error_code = jnp.uint32(_NO_ERROR)
|
||||||
|
|
||||||
|
else: # multi-device case.
|
||||||
|
sharding = NamedSharding(mesh, P(*mesh.axis_names))
|
||||||
|
error_code = jnp.full(
|
||||||
|
mesh.axis_sizes,
|
||||||
|
jnp.uint32(_NO_ERROR),
|
||||||
|
device=sharding,
|
||||||
|
)
|
||||||
|
|
||||||
_error_storage.ref = core.mutable_array(error_code)
|
_error_storage.ref = core.mutable_array(error_code)
|
||||||
|
|
||||||
|
|
||||||
def set_error_if(pred: jax.Array, msg: str) -> None:
|
class error_checking_context:
|
||||||
|
"""Redefine the error checking state based on the mesh in the context.
|
||||||
|
|
||||||
|
This context manager should be used when starting a multi-device
|
||||||
|
computation, and whenever the mesh is changed.
|
||||||
|
|
||||||
|
When exiting the context, the error checking state will be reset to the
|
||||||
|
original state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("old_ref",)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.old_ref = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.old_ref = _error_storage.ref
|
||||||
|
_initialize_error_code_ref()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
_error_storage.ref = self.old_ref
|
||||||
|
|
||||||
|
|
||||||
|
def set_error_if(pred: jax.Array, /, msg: str) -> None:
|
||||||
"""Set error if any element of pred is true.
|
"""Set error if any element of pred is true.
|
||||||
|
|
||||||
If the error is already set, the new error will be ignored. It will not
|
If the error is already set, the new error will be ignored. It will not
|
||||||
override the existing error.
|
override the existing error.
|
||||||
|
|
||||||
|
In auto mode, this function does not work under jit.
|
||||||
"""
|
"""
|
||||||
if _error_storage.ref is None:
|
if _error_storage.ref is None:
|
||||||
_initialize_error_code_ref()
|
_initialize_error_code_ref()
|
||||||
@ -76,7 +124,32 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
|
|||||||
new_error_code = jnp.uint32(len(_error_list))
|
new_error_code = jnp.uint32(len(_error_list))
|
||||||
_error_list.append((msg, traceback))
|
_error_list.append((msg, traceback))
|
||||||
|
|
||||||
|
out_sharding = core.typeof(_error_storage.ref).sharding
|
||||||
|
in_sharding: NamedSharding = core.typeof(pred).sharding
|
||||||
|
|
||||||
|
if out_sharding.mesh.shape_tuple == (): # single-device case.
|
||||||
pred = pred.any()
|
pred = pred.any()
|
||||||
|
else: # multi-device case.
|
||||||
|
has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types
|
||||||
|
if has_auto_axes:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Error checking in auto mode is not supported yet. Please use"
|
||||||
|
" explicit mode."
|
||||||
|
)
|
||||||
|
if out_sharding.mesh != in_sharding.mesh:
|
||||||
|
raise ValueError(
|
||||||
|
"The error code state and the predicate must be on the same mesh, "
|
||||||
|
f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. "
|
||||||
|
"Please use `with error_checking_context()` to redefine the error "
|
||||||
|
"code state based on the mesh."
|
||||||
|
)
|
||||||
|
pred = shard_map(
|
||||||
|
partial(jnp.any, keepdims=True),
|
||||||
|
mesh=out_sharding.mesh,
|
||||||
|
in_specs=in_sharding.spec,
|
||||||
|
out_specs=out_sharding.spec,
|
||||||
|
)(pred) # perform per-device reduction
|
||||||
|
|
||||||
error_code = _error_storage.ref[...]
|
error_code = _error_storage.ref[...]
|
||||||
should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR))
|
should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR))
|
||||||
error_code = jnp.where(should_update, new_error_code, error_code)
|
error_code = jnp.where(should_update, new_error_code, error_code)
|
||||||
@ -93,7 +166,7 @@ def raise_if_error() -> None:
|
|||||||
if _error_storage.ref is None: # if not initialized, do nothing
|
if _error_storage.ref is None: # if not initialized, do nothing
|
||||||
return
|
return
|
||||||
|
|
||||||
error_code = _error_storage.ref[...]
|
error_code = _error_storage.ref[...].min() # reduce to a single error code
|
||||||
if isinstance(error_code, core.Tracer):
|
if isinstance(error_code, core.Tracer):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"raise_if_error() should not be called within a traced context, such as"
|
"raise_if_error() should not be called within a traced context, such as"
|
||||||
@ -101,7 +174,11 @@ def raise_if_error() -> None:
|
|||||||
)
|
)
|
||||||
if error_code == jnp.uint32(_NO_ERROR):
|
if error_code == jnp.uint32(_NO_ERROR):
|
||||||
return
|
return
|
||||||
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)
|
_error_storage.ref[...] = jnp.full(
|
||||||
|
_error_storage.ref.shape,
|
||||||
|
jnp.uint32(_NO_ERROR),
|
||||||
|
device=_error_storage.ref.sharding,
|
||||||
|
) # clear the error code
|
||||||
|
|
||||||
msg, traceback = _error_list[error_code]
|
msg, traceback = _error_list[error_code]
|
||||||
exc = JaxValueError(msg)
|
exc = JaxValueError(msg)
|
||||||
|
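Taken together, these hunks give the error-checking module a sharded error state plus an `error_checking_context` manager. A hypothetical single-device sketch of the API as defined in this diff (the module lives under `jax._src`, so the import path is internal and subject to change):

import jax.numpy as jnp
from jax._src import error_check

def safe_log(x):
  # Records the error in the thread-local error state instead of raising.
  error_check.set_error_if(x <= 0, "log() requires positive inputs")
  return jnp.log(x)

safe_log(jnp.float32(-1.0))
error_check.raise_if_error()  # raises JaxValueError with the message above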
@@ -322,12 +322,15 @@ vmappables: dict[type, tuple[type, type]] = {}
 spec_types: set[type] = {JumbleAxis}

 def unregister_vmappable(data_type: type) -> None:
-  spec_type, axis_size_type = vmappables.pop(data_type)
-  spec_types.remove(spec_type)
+  _, axis_size_type = vmappables.pop(data_type)
   del to_elt_handlers[data_type]
   del from_elt_handlers[data_type]
   if axis_size_type in make_iota_handlers:
     del make_iota_handlers[axis_size_type]
+  global spec_types
+  spec_types = (
+      {JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()}
+  )

 def is_vmappable(x: Any) -> bool:
   return type(x) is Jumble or type(x) in vmappables
@@ -797,7 +797,7 @@ def tracers_to_jaxpr(

   processed_eqn_ids = set()
   eqns: list[core.JaxprEqn] = []
-  for t in toposort([*in_tracers, *out_tracers]):
+  for t in toposort((*in_tracers, *out_tracers)):
     r = t.recipe
     if isinstance(r, JaxprEqnRecipe):
       # TODO broadcast_in_dim can create a new tracer, not present in parents
@@ -237,7 +237,7 @@ def batched_device_put(aval: core.ShapedArray,
           if (isinstance(x, array.ArrayImpl) and
               dispatch.is_single_device_sharding(x.sharding) and
               x.devices() == {d})]
-  if len(bufs) == len(xs):
+  if len(bufs) == len(xs) > 0:
     return array.ArrayImpl(
         aval, sharding, bufs, committed=committed, _skip_checks=True)
   return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
@@ -1026,24 +1026,101 @@ def clz(x: ArrayLike) -> Array:
   r"""Elementwise count-leading-zeros."""
   return clz_p.bind(x)

+@export
 def add(x: ArrayLike, y: ArrayLike) -> Array:
-  r"""Elementwise addition: :math:`x + y`."""
+  r"""Elementwise addition: :math:`x + y`.
+
+  This function lowers directly to the `stablehlo.add`_ operation.
+
+  Args:
+    x, y: Input arrays. Must have matching numerical dtypes. If neither
+      is a scalar, ``x`` and ``y`` must have the same number of dimensions
+      and be broadcast compatible.
+
+  Returns:
+    An array of the same dtype as ``x`` and ``y`` containing the sum
+    of each pair of broadcasted entries.
+
+  See also:
+    - :func:`jax.numpy.add`: NumPy-style addition supporting inputs
+      with mixed dtypes and ranks.
+
+  .. _stablehlo.add: https://openxla.org/stablehlo/spec#add
+  """
   return add_p.bind(x, y)

+@export
 def sub(x: ArrayLike, y: ArrayLike) -> Array:
-  r"""Elementwise subtraction: :math:`x - y`."""
+  r"""Elementwise subtraction: :math:`x - y`.
+
+  This function lowers directly to the `stablehlo.subtract`_ operation.
+
+  Args:
+    x, y: Input arrays. Must have matching numerical dtypes. If neither
+      is a scalar, ``x`` and ``y`` must have the same number of dimensions
+      and be broadcast compatible.
+
+  Returns:
+    An array of the same dtype as ``x`` and ``y`` containing the difference
+    of each pair of broadcasted entries.
+
+  See also:
+    - :func:`jax.numpy.subtract`: NumPy-style subtraction supporting
+      inputs with mixed dtypes and ranks.
+
+  .. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract
+  """
   return sub_p.bind(x, y)

+@export
 def mul(x: ArrayLike, y: ArrayLike) -> Array:
-  r"""Elementwise multiplication: :math:`x \times y`."""
+  r"""Elementwise multiplication: :math:`x \times y`.
+
+  This function lowers directly to the `stablehlo.multiply`_ operation.
+
+  Args:
+    x, y: Input arrays. Must have matching numerical dtypes. If neither
+      is a scalar, ``x`` and ``y`` must have the same number of dimensions
+      and be broadcast compatible.
+
+  Returns:
+    An array of the same dtype as ``x`` and ``y`` containing the product
+    of each pair of broadcasted entries.
+
+  See also:
+    - :func:`jax.numpy.multiply`: NumPy-style multiplication supporting
+      inputs with mixed dtypes and ranks.
+
+  .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
+  """
   return mul_p.bind(x, y)

+@export
 def div(x: ArrayLike, y: ArrayLike) -> Array:
   r"""Elementwise division: :math:`x \over y`.

-  Integer division overflow
-  (division by zero or signed division of INT_SMIN with -1)
-  produces an implementation defined value.
+  This function lowers directly to the `stablehlo.divide`_ operation.
+
+  Integer division overflow (division by zero or signed division of
+  INT_SMIN with -1) produces an implementation defined value.
+
+  Args:
+    x, y: Input arrays. Must have matching numerical dtypes. If neither
+      is a scalar, ``x`` and ``y`` must have the same number of dimensions
+      and be broadcast compatible.
+
+  Returns:
+    An array of the same dtype as ``x`` and ``y`` containing the quotient
+    of each pair of broadcasted entries. For integer inputs, any fractional
+    part is discarded.
+
+  See also:
+    - :func:`jax.numpy.divide`: NumPy-style true division supporting
+      inputs with mixed dtypes and ranks.
+    - :func:`jax.numpy.floor_divide`: NumPy-style floor division supporting
+      inputs with mixed dtypes and ranks.
+
+  .. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide
   """
   return div_p.bind(x, y)

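The new docstrings spell out the contrast between the strict lax primitives and their NumPy-style wrappers. A quick illustration of the distinction the Args sections describe:

import jax.numpy as jnp
from jax import lax

lax.add(jnp.float32(1), jnp.float32(2))  # ok: dtypes match, lowers to stablehlo.add
jnp.add(jnp.float32(1), jnp.int32(2))    # ok: jax.numpy promotes mixed dtypes
# lax.add(jnp.float32(1), jnp.int32(2))  # error: lax requires matching dtypes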
@@ -8422,3 +8499,13 @@ mlir.register_lowering(optimization_barrier_p,
 def _optimization_barrier_batcher(batched_args, batch_dims, **params):
   return optimization_barrier_p.bind(*batched_args, **params), batch_dims
 batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher
+
+
+def _opt_barrier_jvp(primals, tangents):
+  tangents = [ad.instantiate_zeros(t) for t in tangents]
+  return optimization_barrier(primals), optimization_barrier(tangents)
+ad.primitive_jvps[optimization_barrier_p] = _opt_barrier_jvp
+
+
+def _opt_barrier_transpose(cts, *primals):
+  cts = [ad.instantiate_zeros(ct) for ct in cts]
+  return optimization_barrier(cts)
+ad.primitive_transposes[optimization_barrier_p] = _opt_barrier_transpose
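With the JVP and transpose rules registered above, `optimization_barrier` composes with autodiff: the barrier simply wraps both primal and tangent values. A sketch (assuming this commit is applied):

import jax

def f(x):
  # The barrier pins evaluation order without changing the value.
  y = jax.lax.optimization_barrier(x * 2.0)
  return y * y

print(jax.grad(f)(3.0))  # 24.0, the same gradient as without the barrier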
@@ -565,5 +565,5 @@ def use_concrete_mesh(mesh: Mesh | None):
   finally:
     jax_config.device_context.set_local(prev_val)

-def get_concrete_mesh():
+def get_concrete_mesh() -> Mesh | None:
   return jax_config.device_context.value
@@ -15,6 +15,7 @@
 """Module for pallas-core functionality."""
 from __future__ import annotations

+import collections
 from collections.abc import Callable, Iterable, Iterator, Sequence
 import contextlib
 import copy
@@ -1068,6 +1069,17 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **_):
   return [], effs


+class Mesh(Protocol):
+
+  @property
+  def backend(self) -> str:
+    ...
+
+  @property
+  def shape(self) -> collections.OrderedDict[object, int]:
+    ...
+
+
 _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}


@@ -1075,9 +1087,8 @@ def default_mesh_discharge_rule(
     in_avals,
     out_avals,
     *args,
-    grid,
+    mesh,
     compiler_params,
-    backend,
     jaxpr,
     debug,
     interpret,
@@ -1100,19 +1111,22 @@ def default_mesh_discharge_rule(
       if isinstance(eff, state_types.WriteEffect)
   )
   any_spec = BlockSpec(memory_space=MemorySpace.ANY)
+  grid_spec = GridSpec(
+      grid=tuple(mesh.shape.items()),
+      in_specs=[any_spec] * len(in_avals),
+      out_specs=[any_spec] * len(modified_idxs),
+  )
   from jax._src.pallas import pallas_call  # Avoid circular dependency.
-  outs = pallas_call.pallas_call(
+  outs = pallas_call._pallas_call(
       body,
       name=name,
       out_shape=[in_avals[idx] for idx in modified_idxs],
-      in_specs=[any_spec] * len(in_avals),
-      out_specs=[any_spec] * len(modified_idxs),
       input_output_aliases={
           in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs)
       },
-      grid=grid,
+      grid_spec=grid_spec,
+      mesh=mesh,
       compiler_params=compiler_params,
-      backend=backend,
       interpret=interpret,
       debug=debug,
       cost_estimate=cost_estimate,
@@ -340,11 +340,12 @@ def pallas_call_hlo_interpret(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: Any,
     cost_estimate: CostEstimate,
     out_avals: tuple[jax_core.AbstractValue, ...],
 ):
-  del compiler_params, cost_estimate, out_avals
+  del mesh, compiler_params, cost_estimate, out_avals
   debug_info = jaxpr.debug_info
   # If we're in interpret mode, we *scan* over the grid and eval the
   # discharged jaxpr.
@ -211,6 +211,10 @@ class TensorCoreMesh:
|
|||||||
devices: np.ndarray
|
devices: np.ndarray
|
||||||
axis_names: Sequence[str]
|
axis_names: Sequence[str]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def backend(self) -> str:
|
||||||
|
return "mosaic_tpu"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self):
|
def shape(self):
|
||||||
return collections.OrderedDict(zip(self.axis_names, self.devices.shape))
|
return collections.OrderedDict(zip(self.axis_names, self.devices.shape))
|
||||||
@ -259,7 +263,6 @@ def _tensorcore_mesh_discharge_rule(
|
|||||||
compiler_params = TPUCompilerParams()
|
compiler_params = TPUCompilerParams()
|
||||||
if len(mesh.shape) > 1:
|
if len(mesh.shape) > 1:
|
||||||
raise NotImplementedError("Mesh must be 1D")
|
raise NotImplementedError("Mesh must be 1D")
|
||||||
core_axis_name, num_cores = list(mesh.shape.items())[0]
|
|
||||||
if compiler_params.dimension_semantics is not None:
|
if compiler_params.dimension_semantics is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"dimension_semantics must be None for TensorCoreMesh"
|
"dimension_semantics must be None for TensorCoreMesh"
|
||||||
@ -269,13 +272,12 @@ def _tensorcore_mesh_discharge_rule(
|
|||||||
out_avals,
|
out_avals,
|
||||||
*args,
|
*args,
|
||||||
jaxpr=jaxpr,
|
jaxpr=jaxpr,
|
||||||
grid=((core_axis_name, num_cores),),
|
mesh=mesh,
|
||||||
compiler_params=compiler_params.replace(
|
compiler_params=compiler_params.replace(
|
||||||
dimension_semantics=(PARALLEL,)
|
dimension_semantics=(PARALLEL,)
|
||||||
),
|
),
|
||||||
debug=debug,
|
debug=debug,
|
||||||
interpret=interpret,
|
interpret=interpret,
|
||||||
backend="mosaic_tpu",
|
|
||||||
cost_estimate=cost_estimate,
|
cost_estimate=cost_estimate,
|
||||||
name=name,
|
name=name,
|
||||||
)
|
)
|
||||||
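Note: the hunks above move backend selection onto the mesh itself: TensorCoreMesh now reports "mosaic_tpu", and the discharge rule stops passing backend= explicitly. A minimal sketch of the dispatch pattern, using a hypothetical stand-in class rather than the real jax._src.pallas meshes:

    import dataclasses

    @dataclasses.dataclass(frozen=True)
    class TpuMeshStub:  # stand-in for TensorCoreMesh
      @property
      def backend(self) -> str:
        return "mosaic_tpu"

    def resolve_backend(mesh, backend=None):
      # Mirrors the check added to _pallas_call later in this diff: an explicit
      # backend alongside a mesh is rejected; otherwise the mesh names it.
      if mesh is None:
        return backend
      if backend is not None:
        raise ValueError("If `mesh` is specified, then `backend` must be `None`.")
      return mesh.backend

    assert resolve_backend(TpuMeshStub()) == "mosaic_tpu"
    assert resolve_backend(None, "triton") == "triton"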
@@ -1351,12 +1351,13 @@ def interpret_pallas_call(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: Any,
     cost_estimate: CostEstimate,
     out_avals: tuple[jax_core.AbstractValue, ...],
     interpret_params: TPUInterpretParams,
 ):
-  del debug, cost_estimate, out_avals
+  del debug, mesh, cost_estimate, out_avals
 
   # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?)
   dynamic_grid_args, scalars, input_args = split_list(
@@ -108,6 +108,7 @@ def pallas_call_tpu_lowering_rule(
     *in_nodes,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
@@ -116,7 +117,8 @@ def pallas_call_tpu_lowering_rule(
     out_avals: tuple[jax_core.AbstractValue, ...],
 ):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
-  del interpret
+  del mesh, interpret  # Unused.
+
   debug_info = jaxpr._debug_info
   if debug:
     print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
@@ -126,11 +128,11 @@ def pallas_call_tpu_lowering_rule(
   else:
     mosaic_params = {}
 
-  mesh = None
+  jax_mesh = None
   axis_context = ctx.module_context.axis_context
   if axis_context is not None:
     if isinstance(axis_context, sharding_impls.SPMDAxisContext):
-      mesh = axis_context.mesh
+      jax_mesh = axis_context.mesh
   mlir_ctx = mlir.JaxIrContext()
   mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
   mlir_ctx.load_all_available_dialects()
@@ -147,7 +149,7 @@ def pallas_call_tpu_lowering_rule(
         grid_mapping,
         jaxpr,
         dimension_semantics=dimension_semantics,
-        mesh=mesh,
+        mesh=jax_mesh,
         for_verification=for_verification,
         dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(),
     )
@@ -164,11 +166,11 @@ def pallas_call_tpu_lowering_rule(
   )
 
   if promela_dump_path := _DUMP_PROMELA_TO.value:
-    num_devices = 1 if mesh is None else mesh.devices.size
+    num_devices = 1 if jax_mesh is None else jax_mesh.devices.size
     num_cores = (
         jax.devices()[0].num_cores
-        if mesh is None
-        else mesh.devices[0].num_cores
+        if jax_mesh is None
+        else jax_mesh.devices[0].num_cores
     )
     verification_module, _ = lower_module(for_verification=True)
     model = verification.export_promela_model(
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import abc
 import collections
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 import dataclasses
 import enum
 import itertools as it
@@ -519,9 +519,16 @@ class GPUMesh:
       )
 
   @property
-  def shape(self):
+  def backend(self) -> str:
+    return "mosaic_gpu"
+
+  @property
+  def shape(self) -> collections.OrderedDict[object, int]:
+    pairs: Iterable[tuple[object, int]]
     if self.num_threads is not None:
-      pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads))
+      pairs = zip(
+          self.axis_names, (*self.grid, *self.cluster, self.num_threads)
+      )
     else:
       pairs = tuple(
           zip(
@@ -563,8 +570,7 @@ def _gpu_mesh_discharge_rule(
       out_avals,
       *args,
       jaxpr=jaxpr,
-      grid=tuple(mesh.shape.items()),
-      backend="mosaic_gpu",
+      mesh=mesh,
       compiler_params=compiler_params,
       debug=debug,
       interpret=interpret,
@@ -450,6 +450,7 @@ def _block_spec_from_block_mapping(
 
 def lower_pipelined_jaxpr_to_module(
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
    jaxpr: jax_core.Jaxpr,
     compiler_params: dict[str, Any],
     cost_estimate: pallas_core.CostEstimate | None,
@@ -473,7 +474,10 @@ def lower_pipelined_jaxpr_to_module(
       block_mappings, [grid_mapping.num_inputs]
   )
 
-  if grid_mapping.grid_names:  # Last dim corresponds to the warpgroup count
+  if mesh is not None:
+    assert isinstance(mesh, gpu_core.GPUMesh)
+  if mesh and mesh.num_threads is not None:
+    # Last dim corresponds to the warpgroup count.
     block = (128 * grid_mapping.grid[-1], 1, 1)
     grid = grid_mapping.grid[:-1]
   else:
@@ -566,6 +570,7 @@ def lower_pipelined_jaxpr_to_module(
       parallel_grid,
       grid_mapping.grid_names,
       block,
+      mesh.cluster if mesh is not None else (),
       [bm.array_shape_dtype for bm in in_block_mappings],
       [bm.array_shape_dtype for bm in out_block_mappings],
       new_jaxpr,
@@ -578,6 +583,7 @@ def lower_jaxpr_to_module(
     grid: Sequence[int],
     grid_names: Sequence[str],
     block: Sequence[int],
+    cluster: Sequence[int],
     in_shapes: Sequence[jax.ShapeDtypeStruct],
     out_shapes: Sequence[jax.ShapeDtypeStruct],
     jaxpr: jax_core.Jaxpr,
@@ -640,7 +646,7 @@ def lower_jaxpr_to_module(
       mgpu_core._lower_as_gpu_kernel(
           body,
           grid=parallel_grid,
-          cluster=(),
+          cluster=cluster,
           block=block,
           in_shapes=in_shapes,
           out_shape=out_shapes,
@@ -1559,9 +1565,10 @@ def _reduce_lowering_rule_wg(
   if not out_aval.shape:
     # Special-case: reducing to a scalar.
     if x_aval.ndim != 1:
-      # TODO(slebedev): Flatten to 1D, since vector.reduction only supports
-      # 1D inputs.
-      raise NotImplementedError("Only 1D inputs are supported")
+      # Flatten to 1D, since vector.reduction only supports 1D inputs.
+      x = vector_dialect.shape_cast(
+          ir.VectorType.get([x_aval.size], out_type), x
+      )
     return vector_dialect.ReductionOp(out_type, kind, x)
   acc = vector_dialect.splat(
       ir.VectorType.get(out_aval.shape, out_type),
@@ -38,6 +38,7 @@ def pallas_call_lowering(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: dict[str, Any],
     cost_estimate: pallas_core.CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
@@ -63,6 +64,7 @@ def pallas_call_lowering(
 
   lowering_result = lowering.lower_pipelined_jaxpr_to_module(
       grid_mapping,
+      mesh,
       jaxpr,
       compiler_params,
       cost_estimate,
@@ -20,7 +20,7 @@ import dataclasses
 import enum
 from functools import partial, reduce
 import types
-from typing import Any, Literal
+from typing import Any, Literal, cast
 
 import jax
 from jax import lax
@@ -119,6 +119,7 @@ def _pallas_call_jvp_rule(
     jaxpr: jax_core.Jaxpr,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     debug: bool,
     interpret: bool,
     compiler_params: Any,
@@ -133,6 +134,8 @@ def _pallas_call_jvp_rule(
     raise NotImplementedError
   if input_output_aliases:
     raise NotImplementedError("JVP with aliasing not supported.")
+  if mesh is not None:
+    raise NotImplementedError("pallas_call with a mesh does not support JVP")
   nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
   tangents = [t for t in tangents if type(t) is not ad_util.Zero]
   nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs
@@ -181,6 +184,7 @@ def _pallas_call_jvp_rule(
       *tangents,
       jaxpr=jvp_jaxpr,
       grid_mapping=jvp_grid_mapping,
+      mesh=mesh,
       interpret=interpret,
       debug=debug,
       input_output_aliases=(),
@@ -317,6 +321,7 @@ def _batch_with_explicit_loop(
     *,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
@@ -384,6 +389,7 @@ def _batch_with_explicit_loop(
         *batch_args,
         jaxpr=jaxpr,
         grid_mapping=grid_mapping,
+        mesh=mesh,
         input_output_aliases=input_output_aliases,
         debug=debug,
         interpret=interpret,
@@ -413,6 +419,7 @@ def _pallas_call_batching_rule(
     *,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
@@ -421,6 +428,11 @@ def _pallas_call_batching_rule(
     out_avals: tuple[jax_core.AbstractValue, ...],
     backend: _Backend | None,
 ):
+  if mesh is not None:
+    raise NotImplementedError(
+        "pallas_call with a mesh does not support batching"
+    )
+
   def _maybe_squeeze_out_bdim(
       x: jax.Array, bdim: int | batching.NotMapped
   ) -> jax.Array:
@@ -445,6 +457,7 @@ def _pallas_call_batching_rule(
         *args,
         jaxpr=jaxpr,
         grid_mapping=grid_mapping,
+        mesh=mesh,
         input_output_aliases=input_output_aliases,
         debug=debug,
         interpret=interpret,
@@ -478,6 +491,7 @@ def _pallas_call_batching_rule(
         dims=dynamic_grid_dims + dims,
         jaxpr=jaxpr,
         grid_mapping=grid_mapping,
+        mesh=mesh,
         input_output_aliases=input_output_aliases,
         debug=debug,
         interpret=interpret,
@@ -512,6 +526,7 @@ def _pallas_call_batching_rule(
         dims=scalar_bdims + bdims,
         jaxpr=jaxpr,
         grid_mapping=grid_mapping,
+        mesh=mesh,
         input_output_aliases=input_output_aliases,
         debug=debug,
         interpret=interpret,
@@ -890,6 +905,7 @@ def _pallas_call_batching_rule(
       *args,
       jaxpr=jaxpr,
       grid_mapping=batched_grid_mapping,
+      mesh=mesh,
      input_output_aliases=input_output_aliases,
       debug=debug,
       interpret=interpret,
@@ -1339,12 +1355,13 @@ def _pallas_call_state_discharge_rule(
     jaxpr: jax_core.Jaxpr,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     debug: bool,
     interpret: bool,
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
-    backend: _Backend | None = None
+    backend: _Backend | None = None,
 ):
   del avals_out
   assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
@@ -1440,6 +1457,7 @@ def _pallas_call_state_discharge_rule(
           jaxpr=new_jaxpr,
           input_output_aliases=new_input_output_aliases,
           grid_mapping=new_grid_mapping,
+          mesh=mesh,
          debug=debug,
           interpret=interpret,
           compiler_params=compiler_params,
@ -1526,16 +1544,6 @@ def pallas_call(
|
|||||||
invoke the Pallas kernel.
|
invoke the Pallas kernel.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if compiler_params is None:
|
|
||||||
compiler_params = {}
|
|
||||||
if isinstance(compiler_params, pallas_core.CompilerParams):
|
|
||||||
if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
|
|
||||||
compiler_params = {
|
|
||||||
compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
|
|
||||||
}
|
|
||||||
|
|
||||||
if grid_spec is None:
|
if grid_spec is None:
|
||||||
grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes)
|
grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes)
|
||||||
else:
|
else:
|
||||||
@ -1556,6 +1564,55 @@ def pallas_call(
|
|||||||
"If `grid_spec` is specified, then `scratch_shapes` must "
|
"If `grid_spec` is specified, then `scratch_shapes` must "
|
||||||
f"be `()`. It is {scratch_shapes}")
|
f"be `()`. It is {scratch_shapes}")
|
||||||
del grid, in_specs, out_specs
|
del grid, in_specs, out_specs
|
||||||
|
return _pallas_call(
|
||||||
|
kernel,
|
||||||
|
out_shape,
|
||||||
|
grid_spec=grid_spec,
|
||||||
|
input_output_aliases=input_output_aliases,
|
||||||
|
debug=debug,
|
||||||
|
interpret=interpret,
|
||||||
|
name=name,
|
||||||
|
compiler_params=compiler_params,
|
||||||
|
cost_estimate=cost_estimate,
|
||||||
|
backend=backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _pallas_call(
|
||||||
|
kernel: Callable[..., None],
|
||||||
|
out_shape: Any,
|
||||||
|
*,
|
||||||
|
grid_spec: GridSpec,
|
||||||
|
mesh: pallas_core.Mesh | None = None,
|
||||||
|
input_output_aliases: dict[int, int] = {},
|
||||||
|
debug: bool = False,
|
||||||
|
interpret: bool = False,
|
||||||
|
name: str | None = None,
|
||||||
|
compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
|
||||||
|
cost_estimate: CostEstimate | None = None,
|
||||||
|
backend: _Backend | None = None,
|
||||||
|
):
|
||||||
|
if compiler_params is None:
|
||||||
|
compiler_params = {}
|
||||||
|
if isinstance(compiler_params, pallas_core.CompilerParams):
|
||||||
|
if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown platform in compiler params: {compiler_params.PLATFORM}"
|
||||||
|
)
|
||||||
|
compiler_params = {
|
||||||
|
compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mesh is not None:
|
||||||
|
if tuple(mesh.shape.values()) != grid_spec.grid:
|
||||||
|
raise ValueError(
|
||||||
|
f"Mesh shape {tuple(mesh.shape.values())} does not match grid "
|
||||||
|
f"shape {grid_spec.grid}."
|
||||||
|
)
|
||||||
|
if backend is not None:
|
||||||
|
raise ValueError("If `mesh` is specified, then `backend` must be `None`.")
|
||||||
|
backend = cast(_Backend, mesh.backend)
|
||||||
|
|
||||||
grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
|
grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
|
||||||
# TODO(necula): this canonicalization may be convenient for some usage
|
# TODO(necula): this canonicalization may be convenient for some usage
|
||||||
# but it is lossy, because it prevents expressing functions that return
|
# but it is lossy, because it prevents expressing functions that return
|
||||||
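Note: the new private wrapper also insists that an explicit mesh agree with the grid already encoded in grid_spec. The check in isolation, with a stand-in OrderedDict for the mesh's shape property (a sketch, not the library code):

    import collections

    mesh_shape = collections.OrderedDict([("x", 2), ("core", 4)])  # Mesh.shape
    grid = (2, 4)  # grid_spec.grid after canonicalization

    # Mirrors the ValueError raised in _pallas_call.
    if tuple(mesh_shape.values()) != grid:
      raise ValueError(
          f"Mesh shape {tuple(mesh_shape.values())} does not match grid "
          f"shape {grid}."
      )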
@ -1643,6 +1700,7 @@ def pallas_call(
|
|||||||
debug=debug,
|
debug=debug,
|
||||||
interpret=interpret,
|
interpret=interpret,
|
||||||
grid_mapping=grid_mapping,
|
grid_mapping=grid_mapping,
|
||||||
|
mesh=mesh,
|
||||||
input_output_aliases=tuple(input_output_aliases.items()),
|
input_output_aliases=tuple(input_output_aliases.items()),
|
||||||
compiler_params=compiler_params,
|
compiler_params=compiler_params,
|
||||||
cost_estimate=cost_estimate,
|
cost_estimate=cost_estimate,
|
||||||
|
@ -50,6 +50,7 @@ def pallas_call_lowering(
|
|||||||
debug: bool,
|
debug: bool,
|
||||||
input_output_aliases: tuple[tuple[int, int], ...],
|
input_output_aliases: tuple[tuple[int, int], ...],
|
||||||
grid_mapping: pallas_core.GridMapping,
|
grid_mapping: pallas_core.GridMapping,
|
||||||
|
mesh: pallas_core.Mesh | None,
|
||||||
compiler_params: dict[str, Any],
|
compiler_params: dict[str, Any],
|
||||||
cost_estimate: pallas_core.CostEstimate | None,
|
cost_estimate: pallas_core.CostEstimate | None,
|
||||||
out_avals: tuple[jax_core.AbstractValue, ...],
|
out_avals: tuple[jax_core.AbstractValue, ...],
|
||||||
@ -64,6 +65,8 @@ def pallas_call_lowering(
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"scalar prefetch not implemented in the Triton backend"
|
"scalar prefetch not implemented in the Triton backend"
|
||||||
)
|
)
|
||||||
|
if mesh is not None:
|
||||||
|
raise NotImplementedError("mesh is not supported in the Triton backend")
|
||||||
triton_params = compiler_params.get("triton", compiler_params)
|
triton_params = compiler_params.get("triton", compiler_params)
|
||||||
num_warps = triton_params.get("num_warps", 4)
|
num_warps = triton_params.get("num_warps", 4)
|
||||||
num_warps = 4 if num_warps is None else num_warps
|
num_warps = 4 if num_warps is None else num_warps
|
||||||
@@ -670,8 +670,8 @@ def choice(key: ArrayLike,
       ind = jnp.searchsorted(p_cuml, r).astype(int)
     else:
       # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
-      g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)
-      ind = jnp.argsort(g)[:n_draws]
+      g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr)
+      ind = lax.top_k(g, k=n_draws)[1].astype(int)
     result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
 
   return result.reshape(shape if arr.ndim == 0 else
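Note: the rewritten lines are the Gumbel top-k trick in its standard orientation. Negating the noise and sorting ascending (the old code) selects the same indices; taking lax.top_k of g = Gumbel + log p draws k distinct indices without replacement while skipping the full argsort. A self-contained sketch of the trick in plain JAX, independent of choice's internals:

    import jax
    import jax.numpy as jnp
    from jax import lax

    def sample_without_replacement(key, p, k):
      # Gumbel-max: argmax of g is one categorical draw; the top k of g are
      # k distinct draws without replacement with probabilities p.
      g = jax.random.gumbel(key, p.shape) + jnp.log(p)
      return lax.top_k(g, k)[1]

    key = jax.random.key(0)
    p = jnp.array([0.1, 0.2, 0.3, 0.4])
    print(sample_without_replacement(key, p, 2))  # two distinct indices in [0, 4)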
@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array:
|
|||||||
_uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
|
_uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
|
||||||
|
|
||||||
|
|
||||||
def categorical(key: ArrayLike,
|
def categorical(
|
||||||
|
key: ArrayLike,
|
||||||
logits: RealArray,
|
logits: RealArray,
|
||||||
axis: int = -1,
|
axis: int = -1,
|
||||||
shape: Shape | None = None) -> Array:
|
shape: Shape | None = None,
|
||||||
|
replace: bool = True,
|
||||||
|
) -> Array:
|
||||||
"""Sample random values from categorical distributions.
|
"""Sample random values from categorical distributions.
|
||||||
|
|
||||||
|
Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
|
||||||
|
the Gumbel top-k trick. See [1] for reference.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: a PRNG key used as the random key.
|
key: a PRNG key used as the random key.
|
||||||
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
|
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
|
||||||
@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike,
|
|||||||
shape: Optional, a tuple of nonnegative integers representing the result shape.
|
shape: Optional, a tuple of nonnegative integers representing the result shape.
|
||||||
Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
|
Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
|
||||||
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
|
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
|
||||||
|
replace: If True, perform sampling without replacement. Default (False) is to
|
||||||
|
perform sampling with replacement.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A random array with int dtype and shape given by ``shape`` if ``shape``
|
A random array with int dtype and shape given by ``shape`` if ``shape``
|
||||||
is not None, or else ``np.delete(logits.shape, axis)``.
|
is not None, or else ``np.delete(logits.shape, axis)``.
|
||||||
|
|
||||||
|
References:
|
||||||
|
.. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
|
||||||
|
Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
|
||||||
|
Proceedings of the 36th International Conference on Machine Learning, PMLR
|
||||||
|
97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
|
||||||
"""
|
"""
|
||||||
key, _ = _check_prng_key("categorical", key)
|
key, _ = _check_prng_key("categorical", key)
|
||||||
check_arraylike("categorical", logits)
|
check_arraylike("categorical", logits)
|
||||||
logits_arr = jnp.asarray(logits)
|
logits_arr = jnp.asarray(logits)
|
||||||
|
|
||||||
if axis >= 0:
|
|
||||||
axis -= len(logits_arr.shape)
|
|
||||||
|
|
||||||
batch_shape = tuple(np.delete(logits_arr.shape, axis))
|
batch_shape = tuple(np.delete(logits_arr.shape, axis))
|
||||||
if shape is None:
|
if shape is None:
|
||||||
shape = batch_shape
|
shape = batch_shape
|
||||||
else:
|
else:
|
||||||
shape = core.canonicalize_shape(shape)
|
shape = core.canonicalize_shape(shape)
|
||||||
_check_shape("categorical", shape, batch_shape)
|
_check_shape("categorical", shape, batch_shape)
|
||||||
|
|
||||||
shape_prefix = shape[:len(shape)-len(batch_shape)]
|
shape_prefix = shape[:len(shape)-len(batch_shape)]
|
||||||
|
|
||||||
|
if replace:
|
||||||
|
if axis >= 0:
|
||||||
|
axis -= len(logits_arr.shape)
|
||||||
|
|
||||||
logits_shape = list(shape[len(shape) - len(batch_shape):])
|
logits_shape = list(shape[len(shape) - len(batch_shape):])
|
||||||
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
|
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
|
||||||
return jnp.argmax(
|
return jnp.argmax(
|
||||||
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
|
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
|
||||||
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
|
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
|
||||||
axis=axis)
|
axis=axis)
|
||||||
|
else:
|
||||||
|
logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
|
||||||
|
k = math.prod(shape_prefix)
|
||||||
|
if k > logits_arr.shape[axis]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of samples without replacement ({k}) cannot exceed number of "
|
||||||
|
f"categories ({logits_arr.shape[axis]})."
|
||||||
|
)
|
||||||
|
|
||||||
|
_, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k)
|
||||||
|
assert indices.shape == batch_shape + (k,)
|
||||||
|
assert shape == shape_prefix + batch_shape
|
||||||
|
|
||||||
|
dimensions = (indices.ndim - 1, *range(indices.ndim - 1))
|
||||||
|
indices = lax.reshape(indices, shape, dimensions)
|
||||||
|
assert indices.shape == shape
|
||||||
|
return indices
|
||||||
|
|
||||||
|
|
||||||
def laplace(key: ArrayLike,
|
def laplace(key: ArrayLike,
|
||||||
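Note: a usage sketch of the new parameter (output shapes follow the docstring above; with 1D logits the batch shape is empty, so the requested shape is just the number of draws):

    import jax
    import jax.numpy as jnp

    key = jax.random.key(42)
    logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))

    # Default: draws are independent, so indices may repeat across 5 draws.
    jax.random.categorical(key, logits, shape=(5,))

    # Without replacement: at most 4 distinct draws are possible here.
    jax.random.categorical(key, logits, shape=(3,), replace=False)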
@@ -114,9 +114,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
   return sdy_sharding
 
 
-@util.cache(max_size=128, trace_context_in_key=False)
-def get_replicated_hlo_sharding():
-  return xc.HloSharding.replicate()
+replicated_hlo_sharding = xc.HloSharding.replicate()
 
 
 @use_cpp_class(xc.SingleDeviceSharding)
@@ -183,7 +181,7 @@ class SingleDeviceSharding(jsharding.Sharding):
     return (self._device,)
 
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
-    return get_replicated_hlo_sharding()
+    return replicated_hlo_sharding
 
   def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
     sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True)
@@ -401,7 +399,7 @@ def _op_sharding_to_pos_sharding(
 def _positional_sharding_to_xla_hlo_sharding(
     self, num_dimensions: int) -> xc.HloSharding:
   if self.shape == (1,) * self.ndim:
-    return get_replicated_hlo_sharding()
+    return replicated_hlo_sharding
 
   pbuf = xc.OpSharding()
   shape = self.shape[self.ndim - num_dimensions:]  # 'rank promotion' of val
@@ -603,7 +601,7 @@ class GSPMDSharding(jsharding.Sharding):
   @functools.cached_property
   def _hlo_sharding_hash(self):
     if self.is_fully_replicated:
-      return hash(get_replicated_hlo_sharding())
+      return hash(replicated_hlo_sharding)
     return hash(self._hlo_sharding)
 
   def __eq__(self, other):
@@ -669,7 +667,7 @@ class GSPMDSharding(jsharding.Sharding):
 
   @classmethod
   def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
-    return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
+    return cls(tuple(device_assignment), replicated_hlo_sharding,
                memory_kind=memory_kind)
 
 
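Note: HloSharding.replicate() takes no arguments and always yields equivalent objects, so a module-level constant computed at import time can replace the cached accessor; each use site then drops a function call plus a cache lookup. The pattern in miniature (a generic sketch, not the sharding code itself):

    import functools

    def make_value():
      return object()  # stands in for xc.HloSharding.replicate()

    @functools.lru_cache(maxsize=None)
    def get_value():  # before: every use pays a (cached) call
      return make_value()

    VALUE = make_value()  # after: one attribute read per use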
@@ -244,8 +244,15 @@ def curry(f):
   """
   return wraps(f)(partial(partial, f))
 
+# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum.
+toposort: Callable[[Iterable[Any]], list[Any]]
+if hasattr(jaxlib_utils, "topological_sort"):
+  toposort = partial(jaxlib_utils.topological_sort, "parents")
+else:
+
-def toposort(end_nodes):
-  if not end_nodes: return []
+  def toposort(end_nodes):
+    if not end_nodes:
+      return []
     end_nodes = _remove_duplicates(end_nodes)
 
     child_counts = {}
@@ -261,7 +268,9 @@ def toposort(end_nodes):
       child_counts[id(node)] -= 1
 
     sorted_nodes = []
-  childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
+    childless_nodes = [
+        node for node in end_nodes if child_counts[id(node)] == 0
+    ]
     assert childless_nodes
     while childless_nodes:
       node = childless_nodes.pop()
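Note: both branches expose the same contract: nodes carry a parents attribute, and the result lists every ancestor before its dependents (the jaxlib fast path is topological_sort keyed on "parents"; the Python fallback is the Kahn-style walk above). A usage sketch with a hypothetical node type:

    import dataclasses

    @dataclasses.dataclass(eq=False)
    class Node:
      name: str
      parents: list = dataclasses.field(default_factory=list)

    a = Node("a")
    b = Node("b", [a])
    c = Node("c", [a])
    d = Node("d", [b, c])

    # toposort([d]) yields parents before children, e.g. [a, b, c, d]
    # (order among independent nodes like b and c is unspecified).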
@ -291,6 +300,7 @@ def _remove_duplicates(node_list):
|
|||||||
out.append(n)
|
out.append(n)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def split_merge(predicate, xs):
|
def split_merge(predicate, xs):
|
||||||
sides = list(map(predicate, xs))
|
sides = list(map(predicate, xs))
|
||||||
lhs = [x for x, s in zip(xs, sides) if s]
|
lhs = [x for x, s in zip(xs, sides) if s]
|
||||||
@ -658,17 +668,12 @@ def use_cpp_class(cpp_cls: type[Any]) -> Callable[[type[T]], type[T]]:
|
|||||||
|
|
||||||
exclude_methods = {'__module__', '__dict__', '__doc__'}
|
exclude_methods = {'__module__', '__dict__', '__doc__'}
|
||||||
|
|
||||||
originals = {}
|
|
||||||
for attr_name, attr in cls.__dict__.items():
|
for attr_name, attr in cls.__dict__.items():
|
||||||
if attr_name not in exclude_methods:
|
if attr_name not in exclude_methods:
|
||||||
if hasattr(_original_func(attr), "_use_cpp"):
|
if not hasattr(_original_func(attr), "_use_cpp"):
|
||||||
originals[attr_name] = attr
|
|
||||||
else:
|
|
||||||
setattr(cpp_cls, attr_name, attr)
|
setattr(cpp_cls, attr_name, attr)
|
||||||
|
|
||||||
cpp_cls.__doc__ = cls.__doc__
|
cpp_cls.__doc__ = cls.__doc__
|
||||||
# TODO(pschuh): Remove once fastpath is gone.
|
|
||||||
cpp_cls._original_py_fns = originals
|
|
||||||
return cpp_cls
|
return cpp_cls
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for serialization and deserialization of GDA."""
 
 import asyncio
 import math
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for mnist_lib, saved_model_lib, saved_model_main."""
 
 import os
 from absl import flags
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for call_tf."""
 
 from collections.abc import Callable
 import contextlib
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for the jax2tf conversion for control-flow primitives."""
 
 from absl.testing import absltest
 
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for the shape-polymorphic jax2tf conversion."""
 
 from __future__ import annotations
 
@@ -320,6 +320,20 @@ def _vector_splat_op_lowering_rule(
   return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)]
 
 
+@_register_lowering(vector.ShapeCastOp)
+def _vector_shape_cast_op_lowering_rule(
+    _: LoweringContext, op: vector.ShapeCastOp
+) -> Sequence[ir.Value]:
+  [layout] = inference_utils.in_layouts(op)
+  out_vec_ty = ir.VectorType(op.result.type)
+  assert out_vec_ty.has_static_shape
+  is_signed = (
+      False if ir.IntegerType.isinstance(out_vec_ty.element_type) else None
+  )
+  a = _fragmented_array_from_ir(op.source, layout, is_signed)
+  return [_fragmented_array_to_ir(a.reshape(out_vec_ty.shape), out_vec_ty)]
+
+
 @_register_lowering(vector.ReductionOp)
 def _vector_reduction_op_lowering_rule(
     ctx: LoweringContext, op: vector.ReductionOp
@@ -382,21 +382,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]):
   return WGMMA_LAYOUT
 
 
-def _tiled_wgmma_layout_for_upcast(shape: tuple[int, ...]):
-  """Returns a tiled layout that is easy to relayout to WGMMA layout after doubling the bitwidth."""
-  if len(shape) != 2:
-    raise ValueError(f"Shape {shape} is not 2D")
-  if shape[0] % 64 != 0 or shape[1] % 8 != 0:
-    raise ValueError(f"Shape {shape} is not a multiple of 64x8")
-  t = Tiling(((64, 16), (16, 16), (8, 16), (4,), (2, 1)))
-  return TiledLayout(
-      t,
-      warp_dim=-9,
-      lane_dims=(-5, -2, -4),
-      vector_dim=-3,
-  )
-
-
 @dataclasses.dataclass(frozen=True)
 class WGMMARowFragLayout:
   """[m] matrix, where m % 64 == 0."""
@@ -505,13 +490,55 @@ WGMMA_ROW_LAYOUT = WGMMARowFragLayout()
 
 # The tiled layout is equivalent to one described here in PTX documentation:
 # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d
+# In this layout, we partition the 64x8 tiles over 4 warps into 16x8 tiles.
+# Then, we further split the 16x8 tiles into 8x8 submatrices which are the unit
+# of data that is split across a warp. Since 8*8 = 64, but a warp has only 32
+# threads, we vectorize pairs of elements along columns.
+# The assignment of elements to warp lanes is as follows:
+#
+#    0  0  1  1  2  2  3  3
+#    4  4  5  5  6  6  7  7
+#    8  8  9  9 10 10 11 11
+#   12 12 13 13 14 14 15 15
+#   ...
 WGMMA_LAYOUT = TiledLayout(
     Tiling(((64, 8), (16, 8), (8, 8), (1, 2))),
     warp_dim=-8,
     lane_dims=(-4, -3),
     vector_dim=-1,
 )
-# This tiled layout is similar to the one above. Above, each warp stores a 8x8
+# This tiled layout is similar to the WGMMA layout, only the unit at which we
+# assign submatrices to warps grows from 8x8 to 8x16. The elements within each
+# submatrix are assigned to threads in the following way:
+#
+#    0  0  0  0  2  2  2  2  1  1  1  1  3  3  3  3
+#    4  4  4  4  6  6  6  6  5  5  5  5  7  7  7  7
+#    ...
+#
+# Our vector length is twice the size of that of WGMMA_LAYOUT, which lets us use
+# 32-bit SMEM loads/stores when dealing with 8-bit values. The conversion to the
+# WGMMA layout only requires communication between threads whose indices differ
+# in their lowest bit (i.e. 0 and 1, 2 and 3), so the conversion to WGMMA_LAYOUT
+# only requires a single warp shuffle (plus permutes local to each thread).
+WGMMA_LAYOUT_UPCAST_2X = TiledLayout(
+    Tiling(((64, 16), (16, 16), (8, 16), (8,), (4,))),
+    warp_dim=-8,
+    lane_dims=(-4, -2, -3),
+    vector_dim=-1,
+)
+# This layout should be used when upcasting 4-bit elements to 16-bit, for the
+# purpose of passing them into WGMMA later. The core matrices stored by a warp
+# are 8x32, because each of the 4 threads in a row holds 8 elements in a single
+# vector. Note that unlike WGMMA_LAYOUT_UPCAST_2X, we assign columns to each
+# group of 4 threads in order (as opposed to the swapping between 1 and 2,
+# 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does).
+WGMMA_LAYOUT_UPCAST_4X = TiledLayout(
+    Tiling(((64, 32), (16, 32), (8, 32), (8,))),
+    warp_dim=-7,
+    lane_dims=(-3, -2),
+    vector_dim=-1,
+)
+# This tiled layout is similar to WGMMA_LAYOUT. There, each warp stores a 8x8
 # submatrix in the following way (we only show the first 4 rows for brevity):
 #
 #    0  0  1  1  2  2  3  3
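Note: the WGMMA_LAYOUT diagram above has a closed form: within an 8x8 submatrix, lane = 4*row + col // 2, with the two columns of each pair held in one 2-element vector (the innermost (1, 2) tile). A quick check of that reading, as a hypothetical helper rather than Mosaic code:

    import numpy as np

    def wgmma_lane(row: int, col: int) -> int:
      # 4 lanes per row, each owning a pair of adjacent columns.
      return 4 * row + col // 2

    print(np.array([[wgmma_lane(r, c) for c in range(8)] for r in range(4)]))
    # [[ 0  0  1  1  2  2  3  3]
    #  [ 4  4  5  5  6  6  7  7]
    #  [ 8  8  9  9 10 10 11 11]
    #  [12 12 13 13 14 14 15 15]]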
@@ -697,6 +724,7 @@ class FragmentedArray:
     At the moment, only conversions from ``WGSplatFragLayout`` are supported.
     """
     i32 = ir.IntegerType.get_signless(32)
+    c = lambda x: arith.constant(i32, x)
     if self.layout == new_layout:
       return self
     shape = self.shape
@@ -707,24 +735,148 @@ class FragmentedArray:
     ):
       is_even_row = arith.cmpi(
           arith.CmpIPredicate.eq,
-          arith.remui(arith.divui(utils.thread_idx(), c(4, i32)), c(2, i32)),
-          c(0, i32),
+          arith.remui(arith.divui(utils.thread_idx(), c(4)), c(2)),
+          c(0),
       )
-      perm = arith.select(is_even_row, c(0x5410, i32), c(0x3276, i32))
+      perm = arith.select(is_even_row, c(0x5410), c(0x3276))
       new_regs = []
       for reg in self.registers.flat:
         reg_ty = reg.type
         reg = utils.bitcast(reg, i32)
         reg_shfl = utils.shfl_bfly(reg, 4)
-        new_reg = llvm.inline_asm(
-            i32, [reg, reg_shfl, perm], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r"
-        )
+        new_reg = utils.prmt(reg, reg_shfl, perm)
         new_regs.append(utils.bitcast(new_reg, reg_ty))
       return FragmentedArray(
           _registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)),
          _layout=new_layout,
           _is_signed=self.is_signed,
       )
+    if (
+        self.layout == WGMMA_LAYOUT_UPCAST_2X
+        and new_layout == WGMMA_LAYOUT
+        and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16
+    ):
+      assert shape[1] % 16 == 0  # Should be implied by the layout
+      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
+      is_even = arith.cmpi(
+          arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
+      )
+      registers = self.registers
+      if dtype_bitwidth == 4:
+        if registers.shape[1] % 2:
+          raise NotImplementedError(
+              "This relayout implementation requires an even number of column"
+              " tiles (to pack pairs of them for efficiency)"
+          )
+        # We pair up the consecutive column tiles, so each register is 32-bit.
+        # If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout,
+        # LLVM will realize that the paired up vectors actually came from the
+        # same 32-bit register and it will become a no-op.
+        col_minor_registers = np.moveaxis(registers, 1, -1)
+        flat_registers = [
+            utils.vector_concat((l, h))
+            for l, h in zip(
+                col_minor_registers.flat[::2], col_minor_registers.flat[1::2]
+            )
+        ]
+        registers = np.asarray(flat_registers, dtype=object).reshape(
+            *col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2
+        )
+        registers = np.moveaxis(registers, -1, 1)
+      for idx, reg in np.ndenumerate(registers):
+        if dtype_bitwidth == 16:
+          assert reg.type.shape == [4]
+          # A single vector is 64-bits, but shuffles are only 32-bit wide.
+          # We only shuffle the half that needs to go to other thread.
+          low = utils.vector_slice(reg, slice(0, 2))
+          high = utils.vector_slice(reg, slice(2, 4))
+          to_exchange = arith.select(is_even, high, low)
+          # Exchange values between even and odd threads.
+          exchanged = utils.shfl_bfly(to_exchange, 1)
+          low = arith.select(is_even, low, exchanged)
+          high = arith.select(is_even, exchanged, high)
+          new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low
+          new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high
+        elif dtype_bitwidth == 8:
+          assert reg.type.shape == [4]
+          # The vector is 32-bits, so we just shuffle the whole thing and
+          # use prmt to blend it with the local register.
+          exchanged = utils.shfl_bfly(reg, 1)
+          # Consider lanes 0 and 1, because the situation is symmetric for
+          # each pair. If we feed reg[lane] and exchanged[lane] (which is
+          # really the same as reg of the other lane) to prmt, we can index
+          # the elements of the result using the following indices:
+          #          reg[0]: 0 1 2 3   reg[1]: 8 9 10 11
+          #   prmt[0]: 0 1 2 3 4 5 6 7
+          #   prmt[1]: 4 5 6 7 0 1 2 3
+          # The expected outputs and their respective permutations are:
+          #   out[0]: 0 1 8 9     out[1]: 2 3 10 11
+          #   prmt[0]: 0 1 4 5    prmt[1]: 6 7 2 3
+          # Note that the patterns still need to be flipped, since we listed
+          # bytes with LSB on the left, which is the opposite of how the
+          # numeric constants are spelled in Python (LSB on the right).
+          perm = arith.select(is_even, c(0x5410), c(0x3276))
+          blend = utils.prmt(reg, exchanged, perm)
+          for i in range(2):
+            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
+            new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
+        else:
+          assert dtype_bitwidth == 4
+          assert reg.type.shape == [8]  # We paired up the registers above.
+          exchanged = utils.shfl_bfly(reg, 1)
+          # See comment above for a more complete explanation.
+          #          reg[0]: 0 1 2 3 16 17 18 19   reg[1]: 8 9 10 11 24 25 26 27
+          #   prmt[0]: -0- -1- --2-- --3-- -4- --5-- --6-- --7--
+          #   prmt[1]: -4- -5- --6-- --7-- -0- --1-- --2-- --3--
+          # The expected outputs and their respective permutations are:
+          #   out[0]: 0 1 8 9 16 17 24 25    out[1]: 2 3 10 11 18 19 26 27
+          #   prmt[0]: -0- -4- --2-- --6--   prmt[1]: -5- --1-- --7-- --3--
+          perm = arith.select(is_even, c(0x6240), c(0x3715))
+          blend = utils.prmt(reg, exchanged, perm)
+          for i in range(4):
+            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
+            new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg
+      assert all(r is not None for r in new_registers)
+      return FragmentedArray(
+          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
+      )
+    if (
+        self.layout == WGMMA_LAYOUT_UPCAST_4X
+        and new_layout == WGMMA_LAYOUT_UPCAST_2X
+        and utils.bitwidth(self.mlir_dtype) == 4
+    ):
+      assert shape[0] % 64 == 0  # Should be implied by the layout
+      assert shape[1] % 32 == 0  # Should be implied by the layout
+      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
+      is_01 = arith.cmpi(
+          arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2)
+      )
+      for idx, reg in np.ndenumerate(self.registers):
+        assert ir.VectorType(reg.type).shape == [8]
+        # The vector is 32-bits, so we just shuffle the whole thing and
+        # use prmt to blend it with the local register.
+        exchanged = utils.shfl_bfly(reg, 2)
+        # See comments above for conventions. Here we exchange data between
+        # threads with lane index related by flipping 2nd bit (e.g. 0 and 2).
+        #          reg[0]: 0 1 2 3 4 5 6 7   reg[2]: 16 17 18 19 20 21 22 23
+        #   prmt[0]: -0- -1- -2- -3- --4-- --5-- --6-- --7--
+        #   prmt[1]: -4- -5- -6- -7- --0-- --1-- --2-- --3--
+        # The expected outputs and their respective permutations are:
+        #   out[0]: 0 1 2 3 16 17 18 19    out[2]: 4 5 6 7 20 21 22 23
+        #   prmt[0]: -0- -1- --4-- --5--   prmt[2]: -6- -7- --2-- --3--
+        perm = arith.select(is_01, c(0x5410), c(0x3276))
+        blend = utils.prmt(reg, exchanged, perm)
+        for i in range(2):
+          reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4))
+          new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
+      assert all(r is not None for r in new_registers)
+      return FragmentedArray(
+          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
+      )
+    if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT:
+      return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout)
     if not isinstance(self.layout, WGSplatFragLayout):
       raise NotImplementedError(
           f"Cannot convert from {self.layout} to {new_layout}"
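Note: every blend above leans on the PTX byte-permute prmt.b32 d, a, b, s: the bytes of a and b form an 8-entry pool (a supplies bytes 0-3, b supplies bytes 4-7), and each selector nibble, read from the least significant end, picks one output byte. A Python model of the common case (ignoring prmt's sign-replication modes), checking the 0x5410/0x3276 pair used repeatedly above:

    def prmt(a: int, b: int, selector: int) -> int:
      pool = [(a >> (8 * i)) & 0xFF for i in range(4)]
      pool += [(b >> (8 * i)) & 0xFF for i in range(4)]
      out = 0
      for i in range(4):
        idx = (selector >> (4 * i)) & 0x7  # one nibble per output byte
        out |= pool[idx] << (8 * i)
      return out

    a = 0x33221100  # byte i of a holds 0x11 * i
    b = 0x77665544
    assert prmt(a, b, 0x5410) == 0x55441100  # low halves of a and b
    assert prmt(a, b, 0x3276) == 0x33227766  # high halves, b's first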
@ -1178,11 +1330,15 @@ class FragmentedArray:
|
|||||||
is_vector_reg = ir.VectorType.isinstance(reg_type)
|
is_vector_reg = ir.VectorType.isinstance(reg_type)
|
||||||
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
|
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
|
||||||
[vector_len] = reg_shape # This is meant to be a 1D assertion.
|
[vector_len] = reg_shape # This is meant to be a 1D assertion.
|
||||||
if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len == 2:
|
if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8:
|
||||||
|
raise ValueError(
|
||||||
|
"Register bitwidth in target type must be divisible by 8, got"
|
||||||
|
f" {new_reg_bitwidth}"
|
||||||
|
)
|
||||||
|
if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
|
||||||
new_registers = np.empty_like(self.registers)
|
new_registers = np.empty_like(self.registers)
|
||||||
empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32))
|
@@ -1196,25 +1352,58 @@ class FragmentedArray:
+        out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
       for idx, reg in np.ndenumerate(self.registers):
-        reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
         # The algorithm here is largely the same as CUTLASS's
         # NumericArrayConverter specialization for int4 -> bf16 casts.
         # We modify it slightly, because we only extract 2 values.
         # positive int4s will end up larger than negative int4s, with a bias of
         # 8. We use the sub to subtract the base (our initial exponent) and the
         # bias coming from flipping the sign bit which is 136 (0x4308 as bits).
-        new_reg_32 = llvm.inline_asm(
+        def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
+          assert 0 <= part < 4
+          return llvm.inline_asm(
             i32,
-            [reg_8],
-            """
-            {
+            [reg, reg_shr],
+            f"""
+            {{
             .reg .b32 s<4>;
-            shr.s32 s0, $1, 4;
-            prmt.b32 s1, $1, s0, 0xF4F0;
+            prmt.b32 s1, $1, $2, 0xF{part + 4}F{part};
             lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
             mov.b32 s3, 0x43084308;
             sub.bf16x2 $0, s2, s3;
-            }
+            }}
             """,
-            "=r,r",
+            "=r,r,r",
           )
-        new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
-        new_registers[idx] = vector.bitcast(
-            ir.VectorType.get((vector_len,), new_dtype), new_vec_32
+        offset = 0
+        out_int_regs = []
+        for group_size in (8, 4, 2):
+          int_ty = ir.IntegerType.get_signless(group_size * 4)
+          while vector_len - offset >= group_size:
+            # If the vector originates from a slice (common after relayouts), we
+            # can fuse the slicing into the conversion and prevent LLVM from
+            # generating a bunch of shifts to align the vector data to the LSB.
+            # This also lets us share the right shift among more vectors.
+            if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
+                and utils.bitwidth(slice_op.vector.type) == 32
+                and slice_op.strides[0].value == 1):
+              slice_offset = slice_op.offsets[0].value + offset
+              reg_int = utils.bitcast(slice_op.vector, i32)
+              reg_int_shr = arith.shrui(reg_int, c(4, i32))
+              out_int_regs.extend(
+                  upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
+                  for part in range(group_size // 2)
               )
+            else:
+              reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
+              reg_slice_int = utils.bitcast(reg_slice, int_ty)
+              if int_ty != i32:
+                reg_slice_int = arith.extsi(i32, reg_slice_int)
+              reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
+              out_int_regs.extend(
+                  upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
+                  for part in range(group_size // 2)
+              )
+            offset += group_size
+        assert offset == vector_len
+        out_vec_int = utils.vector_concat([
+            vector.splat(ir.VectorType.get((1,), i32), reg)
+            for reg in out_int_regs
+        ])
+        new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
       return FragmentedArray(
           _registers=new_registers, _layout=self.layout, _is_signed=None
       )
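For readers unfamiliar with the CUTLASS trick referenced in the comments above, the arithmetic can be checked off-GPU: the `lop3` masks the low nibble of each 16-bit half and XORs it into the bf16 constant 0x4308 (136.0), which simultaneously flips the int4 sign bit, and the final `sub.bf16x2` then recovers the signed value exactly. A minimal NumPy sketch of that identity (illustrative only; `bf16_bits_to_f32` is a hypothetical helper, not part of the diff):

```python
import numpy as np

def bf16_bits_to_f32(bits16: int) -> float:
  # bf16 is the top 16 bits of an f32, so widening is a plain shift.
  return np.array([bits16 << 16], dtype=np.uint32).view(np.float32)[0]

for n in range(-8, 8):  # every signed int4 value
  payload = n & 0xF  # its two's complement bit pattern
  # lop3 computes (s1 & 0x000F000F) ^ 0x43084308 per 16-bit half; XOR with
  # 0x4308 sets the exponent bits *and* flips the payload's sign bit.
  embedded = payload ^ 0x4308
  recovered = bf16_bits_to_f32(embedded) - bf16_bits_to_f32(0x4308)
  assert recovered == float(n), (n, recovered)
```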
@@ -1263,11 +1452,6 @@ class FragmentedArray:
         _registers=new_registers, _layout=self.layout, _is_signed=is_signed
     )
     # Generic path.
-    # XLA packs elements into bytes in big-endian order, while LLVM assumes the
-    # same endianness as the target machine (which is little for NVIDIA GPUs).
-    # We'll need to add specialized casting routines that flip the endianness.
-    if 1 < utils.bitwidth(cur_dtype) < 8 or 1 < utils.bitwidth(new_dtype) < 8:
-      raise NotImplementedError("Conversion involving sub-byte types unsupported")
     from_float = ir.FloatType.isinstance(cur_dtype)
     to_float = ir.FloatType.isinstance(new_dtype)
     from_integer = ir.IntegerType.isinstance(cur_dtype)
@@ -1472,17 +1656,17 @@ class FragmentedArray:
   def reshape(self, shape):
     if self.shape == shape:
       return self
-    if not isinstance(self.layout, WGSplatFragLayout):
-      raise NotImplementedError(self.layout)
-
-    if np.prod(shape) != np.prod(self.shape):
+    if math.prod(shape) != math.prod(self.shape):
       raise ValueError(f"Can't reshape {self.shape} to {shape}")

+    match self.layout:
+      case WGSplatFragLayout() | WGStridedFragLayout():
+        new_layout = dataclasses.replace(self.layout, shape=shape)
+      case _:
+        raise NotImplementedError(self.layout)

     return FragmentedArray(
-        _registers=self.registers,
-        _layout=WGSplatFragLayout(shape),
-        _is_signed=self.is_signed,
+        _registers=self.registers, _layout=new_layout, _is_signed=self.is_signed
     )

   def broadcast_minor(self, n):
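One likely motivation for swapping `np.prod` for `math.prod` in the size check above (our observation, not stated in the commit): `math.prod` is exact for arbitrarily large Python ints, while `np.prod` computes in a fixed-width integer dtype and can silently wrap.

```python
import math
import numpy as np

shape = (2**32, 2**32)
assert math.prod(shape) == 2**64     # exact big-int arithmetic
assert int(np.prod(shape)) != 2**64  # wraps in int64 on typical platforms
```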
@@ -336,6 +336,37 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:

   return [], [layout]


+def _update_layout_shape(
+    layout: ir.Attribute, shape: Sequence[int], origin: str
+) -> ir.Attribute:
+  if layouts_lib.is_splat_fragmented_layout(
+      layout
+  ) or layouts_lib.is_strided_fragmented_layout(layout):
+    return layouts_lib.to_layout_attr(
+        dataclasses.replace(layouts_lib.from_layout_attr(layout), shape=shape)
+    )
+  raise NotImplementedError(f"Unsupported {origin} layout: {layout}.")
+
+
+@partial(_add_layout_inference_rule, vector.ShapeCastOp)
+def _infer_shape_cast_op_layout(op: vector.ShapeCastOp) -> OptionalLayouts:
+  in_layout = inference_utils.value_layout(op.source)
+  if in_layout is None:
+    out_layout = inference_utils.value_layout(op.result)
+    if out_layout is None:
+      return None
+    in_layout = _update_layout_shape(
+        out_layout, ir.VectorType(op.source.type).shape, "source"
+    )
+    return [in_layout], [out_layout]
+
+  out_layout = _update_layout_shape(
+      in_layout, ir.VectorType(op.result.type).shape, "result"
+  )
+  return [in_layout], [out_layout]
+
+
 @partial(_add_layout_inference_rule, vector.ReductionOp)
 def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts:
   if layout := inference_utils.value_layout(op.vector):
@@ -83,6 +83,8 @@ def mma(
     accumulate: ir.Value | bool = True,
     collective: bool = False,
 ):
+  if a_swizzle == 16 or b_swizzle == 16:
+    raise NotImplementedError("No swizzle is not supported")
   i32 = ir.IntegerType.get_signless(32)
   i64 = ir.IntegerType.get_signless(64)
   if isinstance(accumulate, bool):
@@ -25,8 +25,12 @@ from typing import cast

 from jax._src.lib import mosaic_gpu_dialect as mgpu
 from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import arith
+from jax._src.lib.mlir.dialects import vector

+from . import fragmented_array as fa
 from . import inference_utils
+from . import layouts as layouts_lib
 from . import utils

 # mypy: ignore-errors
@@ -39,7 +43,9 @@ _transform_inference_rules: dict[str, TransformInferenceRule] = {}
 def _add_transform_inference_rule(
     op: type[ir.OpView], rule: TransformInferenceRule
 ):
+  if op is not None:
     _transform_inference_rules[op.OPERATION_NAME] = rule  # pytype: disable=attribute-error
+  return rule


 def _set_transform_attributes(
@@ -110,6 +116,86 @@ def _infer_async_load_transforms(op: mgpu.AsyncLoadOp) -> OptionalTransforms:
   return None if in_transforms is None else ([in_transforms], [])


+@partial(_add_transform_inference_rule, vector.LoadOp)
+@partial(_add_transform_inference_rule, vector.StoreOp)
+def _infer_vector_load_store_transforms(
+    op: vector.LoadOp | vector.StoreOp,
+) -> OptionalTransforms:
+  for i in op.indices:
+    index_defining_op = i.owner.opview
+    if (
+        not isinstance(index_defining_op, arith.ConstantOp)
+        or index_defining_op.literal_value != 0
+    ):
+      # TODO(bchetioui): handle slicing.
+      raise NotImplementedError(
+          f"Only constants with value 0 are supported as indices for {op}"
+      )
+
+  if isinstance(op, vector.LoadOp):
+    [layout_attr] = inference_utils.out_layouts(op)
+  else:
+    assert isinstance(op, vector.StoreOp)
+    [layout_attr] = inference_utils.in_layouts(op)
+
+  layout = layouts_lib.from_layout_attr(layout_attr)
+  transforms = inference_utils.value_transforms(op.base)
+
+  if layout == fa.WGMMA_LAYOUT:
+    layout_transforms = infer_transforms_for_wgmma_ref(
+        ir.MemRefType(op.base.type)
+    )
+  elif (isinstance(layout, fa.WGStridedFragLayout) or
+        isinstance(layout, fa.WGSplatFragLayout)):
+    layout_transforms = None
+  else:
+    raise NotImplementedError(
+        f"Got layout {layout} which is not yet supported"
+    )
+
+  if transforms is not None and layout_transforms is not None:
+    if transforms != layout_transforms:
+      raise NotImplementedError(
+          f"Conflicting transforms for {op.base} in {op}: "
+          f"{transforms} != {layout_transforms}."
+      )
+    return [transforms], []
+
+  if transforms is not None:
+    return [transforms], []
+
+  if layout_transforms is not None:
+    return [layout_transforms], []
+
+  return None
+
+
+# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
+SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
+
+
+@partial(_add_transform_inference_rule, SliceSMEMOp)
+def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
+  transforms = None
+  uses = cast(ir.OpResult, op.result).uses
+
+  for op_operand_use in uses:
+    consumer = op_operand_use.owner
+    op_user = consumer.operands[op_operand_use.operand_number]
+    out_transforms = inference_utils.in_transforms_for_operand(
+        consumer, op_user
+    )
+    if transforms is not None and out_transforms is not None:
+      if transforms != out_transforms:
+        raise NotImplementedError(
+            f"Conflicting transforms for {op_user} in {op}: "
+            f"{transforms} != {out_transforms}."
+        )
+    elif out_transforms is not None:
+      transforms = out_transforms
+
+  return None if transforms is None else ([], [transforms])
+
+
 def _should_have_transforms(op: ir.OpView) -> bool:
   """Returns 'True' if the operation should be assigned in/out transforms."""
   return any(
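The precedence between user-annotated transforms and layout-derived transforms in `_infer_vector_load_store_transforms` above can be summarized in isolation. A small standalone sketch of the same merge rule (illustrative only, with plain strings standing in for transform attributes):

```python
def merge(transforms, layout_transforms):
  # Both sources set: they must agree. One source set: it wins. Neither: None.
  if transforms is not None and layout_transforms is not None:
    if transforms != layout_transforms:
      raise NotImplementedError("Conflicting transforms")
    return transforms
  return transforms if transforms is not None else layout_transforms

assert merge("tile(8, 64)", None) == "tile(8, 64)"
assert merge(None, "tile(8, 64)") == "tile(8, 64)"
assert merge("tile(8, 64)", "tile(8, 64)") == "tile(8, 64)"
assert merge(None, None) is None
```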
@@ -346,8 +346,11 @@ def bitwidth_impl(ty: ir.Type):
     return ir.IntegerType(ty).width
   if ir.FloatType.isinstance(ty):
     return ir.FloatType(ty).width
-  if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
+  if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
     return MBARRIER_BYTES * 8
+  if ir.VectorType.isinstance(ty):
+    vty = ir.VectorType(ty)
+    return math.prod(vty.shape) * bitwidth(vty.element_type)
   raise NotImplementedError(ty)


@@ -1180,13 +1183,33 @@ def shfl_bfly(x: ir.Value, distance: int | ir.Value):
   i32 = ir.IntegerType.get_signless(32)
   if isinstance(distance, int):
     distance = c(distance, i32)
-  assert x.type == i32
-  return nvvm.shfl_sync(
+  if (result_type := x.type) != i32:
+    x = bitcast(x, i32)
+  y = nvvm.shfl_sync(
       i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly,
   )
+  return bitcast(y, result_type)
+
+
+def prmt(high: ir.Value, low: ir.Value, permutation: ir.Value):
+  i32 = ir.IntegerType.get_signless(32)
+  if (result_type := high.type) != low.type:
+    raise ValueError(f"Types must match, got {high.type} and {low.type}")
+  if high.type != i32:
+    high = bitcast(high, i32)
+  if low.type != i32:
+    low = bitcast(low, i32)
+  if permutation.type != i32:
+    permutation = bitcast(permutation, i32)
+  result = llvm.inline_asm(
+      i32, [high, low, permutation], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r"
+  )
+  return bitcast(result, result_type)


 def bitcast(x: ir.Value, new_type: ir.Type):
+  if x.type == new_type:
+    return x
   if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type):
     new_type = ir.IntegerType(new_type)
     x_ty = ir.VectorType(x.type)
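The `prmt` helper added above wraps PTX's byte-permute instruction. A rough Python model of its default mode (our sketch; it deliberately omits the sign-replication selector modes, which the int4 upcast path also uses) may help when reading selectors like `0xF4F0`:

```python
def prmt_model(a: int, b: int, sel: int) -> int:
  """Models prmt.b32 d, a, b, sel for selector nibbles 0-7.

  The eight bytes of the pair (b, a) form a lookup table; each selector
  nibble picks the byte for one output position. Nibbles with bit 3 set
  (8-F) instead replicate the chosen byte's sign bit, which this toy
  model does not implement.
  """
  table = [(a >> (8 * i)) & 0xFF for i in range(4)]
  table += [(b >> (8 * i)) & 0xFF for i in range(4)]
  out = 0
  for i in range(4):
    nib = (sel >> (4 * i)) & 0xF
    assert nib < 8, "sign-replication modes not modeled here"
    out |= table[nib] << (8 * i)
  return out

assert prmt_model(0x11223344, 0xAABBCCDD, 0x3210) == 0x11223344  # identity
assert prmt_model(0x11223344, 0xAABBCCDD, 0x5410) == 0xCCDD3344  # mix halves
```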
@@ -1200,8 +1223,50 @@ def bitcast(x: ir.Value, new_type: ir.Type):
     x_ty = ir.IntegerType(x.type)
     assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
     return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
+  if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
+    x_ty = ir.VectorType(x.type)
+    new_ty = ir.VectorType(new_type)
+    if bitwidth(x_ty) != bitwidth(new_ty):
+      raise ValueError(f"Can't bitcast {x.type} to {new_type}")
+    return vector.bitcast(new_type, x)
   raise ValueError(f"Can't bitcast {x.type} to {new_type}")


 def ceil_div(x: int, y: int):
   return (x + y - 1) // y


+def vector_slice(v: ir.Value, s: slice):
+  v_ty = ir.VectorType(v.type)
+  if len(v_ty.shape) != 1:
+    raise NotImplementedError(v_ty)
+  [v_len] = v_ty.shape
+  slice_length = len(range(v_len)[s])
+  return vector.extract_strided_slice(
+      ir.VectorType.get((slice_length,), v_ty.element_type),
+      v, [s.start or 0], [slice_length], [1],
+  )
+
+
+def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
+  index = ir.IndexType.get()
+  if not vectors:
+    raise ValueError("Cannot concatenate an empty list of vectors")
+  vty = vectors[0].type
+  if not ir.VectorType.isinstance(vty):
+    raise ValueError("Cannot concatenate non-vector values")
+  if vty.rank != 1:
+    raise NotImplementedError("Only 1D vectors are supported")
+  for v in vectors:
+    if v.type != vty:
+      raise ValueError("Cannot concatenate vectors of different types")
+  result = llvm.mlir_undef(
+      ir.VectorType.get((vty.shape[0] * len(vectors),), vty.element_type)
+  )
+  offset = 0
+  for v in vectors:
+    for i in range(vty.shape[0]):
+      elem = vector.extractelement(v, position=c(i, index))
+      result = vector.insertelement(elem, result, position=c(offset + i, index))
+    offset += vty.shape[0]
+  return result
|
|||||||
The refs must be contiguous or be contiguous except for having their two minor
|
The refs must be contiguous or be contiguous except for having their two minor
|
||||||
dimensions swapped.
|
dimensions swapped.
|
||||||
"""
|
"""
|
||||||
|
if swizzle == 16:
|
||||||
|
raise NotImplementedError("No swizzle is not supported")
|
||||||
# Step 1. Establish the shape and element type of the operation.
|
# Step 1. Establish the shape and element type of the operation.
|
||||||
if not ir.MemRefType.isinstance(b.type):
|
if not ir.MemRefType.isinstance(b.type):
|
||||||
raise ValueError(f"B must be a memref, got: {b.type}")
|
raise ValueError(f"B must be a memref, got: {b.type}")
|
||||||
|
@ -214,6 +214,8 @@ nanobind_extension(
|
|||||||
module_name = "utils",
|
module_name = "utils",
|
||||||
deps = [
|
deps = [
|
||||||
"@com_google_absl//absl/cleanup",
|
"@com_google_absl//absl/cleanup",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
"@nanobind",
|
"@nanobind",
|
||||||
|
@ -65,6 +65,7 @@ cc_library(
|
|||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
"@llvm-project//mlir:TransformUtils",
|
"@llvm-project//mlir:TransformUtils",
|
||||||
|
"@llvm-project//mlir:VectorDialect",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -238,11 +238,12 @@ NB_MODULE(_mosaic_gpu_ext, m) {
|
|||||||
"failed to enable tracking of kernel activity by CUPTI");
|
"failed to enable tracking of kernel activity by CUPTI");
|
||||||
});
|
});
|
||||||
m.def("_cupti_get_timings", []() {
|
m.def("_cupti_get_timings", []() {
|
||||||
THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber),
|
THROW_IF_CUPTI_ERROR(
|
||||||
"failed to unsubscribe from CUPTI");
|
cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED),
|
||||||
THROW_IF_CUPTI_ERROR(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_NONE),
|
|
||||||
"failed to flush CUPTI activity buffers");
|
"failed to flush CUPTI activity buffers");
|
||||||
THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
|
THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
|
||||||
|
THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber),
|
||||||
|
"failed to unsubscribe from CUPTI");
|
||||||
return profiler_state.timings;
|
return profiler_state.timings;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "jaxlib/mosaic/gpu/passes.h"
|
#include "jaxlib/mosaic/gpu/passes.h"
|
||||||
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -23,6 +24,7 @@ limitations under the License.
|
|||||||
#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
|
#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||||
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
|
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||||
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
|
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/include/mlir/IR/BuiltinOps.h"
|
#include "mlir/include/mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/include/mlir/IR/SymbolTable.h"
|
#include "mlir/include/mlir/IR/SymbolTable.h"
|
||||||
@ -36,6 +38,49 @@ namespace gpu {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Upstream MLIR does not implement an LLVM lowering pattern for this op.
|
||||||
|
struct ConvertExtractStridedSlicePattern final
|
||||||
|
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
mlir::LogicalResult matchAndRewrite(
|
||||||
|
mlir::vector::ExtractStridedSliceOp op, OpAdaptor subst,
|
||||||
|
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto vty = op.getSourceVectorType();
|
||||||
|
if (vty.getRank() != 1) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported");
|
||||||
|
}
|
||||||
|
int64_t size =
|
||||||
|
(*op.getSizes().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
|
||||||
|
if (size < 0) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "size is negative");
|
||||||
|
}
|
||||||
|
int64_t start =
|
||||||
|
(*op.getOffsets().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
|
||||||
|
int64_t stride =
|
||||||
|
(*op.getStrides().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
|
||||||
|
if (stride != 1) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "only stride 1 is supported");
|
||||||
|
}
|
||||||
|
if (start < 0 || start + size > vty.getShape()[0]) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "slice is out of bounds");
|
||||||
|
}
|
||||||
|
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
|
||||||
|
op.getLoc(), op.getResult().getType());
|
||||||
|
for (int64_t i = 0; i < size; ++i) {
|
||||||
|
result = rewriter.create<mlir::LLVM::InsertElementOp>(
|
||||||
|
op.getLoc(), result,
|
||||||
|
rewriter.create<mlir::LLVM::ExtractElementOp>(
|
||||||
|
op.getLoc(), subst.getVector(),
|
||||||
|
rewriter.create<mlir::LLVM::ConstantOp>(
|
||||||
|
op.getLoc(), rewriter.getI32IntegerAttr(i + start))),
|
||||||
|
rewriter.create<mlir::LLVM::ConstantOp>(
|
||||||
|
op.getLoc(), rewriter.getI32IntegerAttr(i)));
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class ConvertGpuToLLVMPass
|
class ConvertGpuToLLVMPass
|
||||||
: public jaxlib::mlir::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp> {
|
: public jaxlib::mlir::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp> {
|
||||||
public:
|
public:
|
||||||
@ -58,6 +103,7 @@ class ConvertGpuToLLVMPass
|
|||||||
});
|
});
|
||||||
auto symtab = mlir::SymbolTable(getOperation());
|
auto symtab = mlir::SymbolTable(getOperation());
|
||||||
mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false);
|
mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false);
|
||||||
|
patterns.insert<ConvertExtractStridedSlicePattern>(&getContext());
|
||||||
if (mlir::applyPartialConversion(getOperation(), target,
|
if (mlir::applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))
|
std::move(patterns))
|
||||||
.failed()) {
|
.failed()) {
|
||||||
|
@ -16,9 +16,13 @@ limitations under the License.
|
|||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "nanobind/nanobind.h"
|
#include "nanobind/nanobind.h"
|
||||||
#include "absl/cleanup/cleanup.h"
|
#include "absl/cleanup/cleanup.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
|
|
||||||
@ -293,6 +297,69 @@ PyMethodDef safe_zip_def = {
|
|||||||
METH_FASTCALL,
|
METH_FASTCALL,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
nb::list TopologicalSort(nb::str parents_attr,
|
||||||
|
nb::iterable end_nodes_iterable) {
|
||||||
|
// This is a direct conversion of the original Python implementation.
|
||||||
|
// More efficient implementations of a topological sort are possible (and
|
||||||
|
// indeed, easier to write), but changing the choice of topological order
|
||||||
|
// would break existing tests.
|
||||||
|
std::vector<nb::object> end_nodes;
|
||||||
|
absl::flat_hash_set<PyObject*> seen;
|
||||||
|
for (nb::handle n : end_nodes_iterable) {
|
||||||
|
nb::object node = nb::borrow(n);
|
||||||
|
if (seen.insert(node.ptr()).second) {
|
||||||
|
end_nodes.push_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nb::list sorted_nodes;
|
||||||
|
if (end_nodes.empty()) {
|
||||||
|
return sorted_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<nb::object> stack = end_nodes;
|
||||||
|
absl::flat_hash_map<PyObject*, int> child_counts;
|
||||||
|
while (!stack.empty()) {
|
||||||
|
nb::object node = std::move(stack.back());
|
||||||
|
stack.pop_back();
|
||||||
|
auto& count = child_counts[node.ptr()];
|
||||||
|
if (count == 0) {
|
||||||
|
for (nb::handle parent : node.attr(parents_attr)) {
|
||||||
|
stack.push_back(nb::borrow(parent));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++count;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (nb::handle n : end_nodes) {
|
||||||
|
child_counts[n.ptr()] -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<nb::object> childless_nodes;
|
||||||
|
childless_nodes.reserve(end_nodes.size());
|
||||||
|
for (nb::handle n : end_nodes) {
|
||||||
|
if (child_counts[n.ptr()] == 0) {
|
||||||
|
childless_nodes.push_back(nb::borrow(n));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!childless_nodes.empty()) {
|
||||||
|
nb::object node = std::move(childless_nodes.back());
|
||||||
|
childless_nodes.pop_back();
|
||||||
|
sorted_nodes.append(node);
|
||||||
|
for (nb::handle parent : node.attr(parents_attr)) {
|
||||||
|
auto& count = child_counts[parent.ptr()];
|
||||||
|
if (count == 1) {
|
||||||
|
childless_nodes.push_back(nb::borrow(parent));
|
||||||
|
} else {
|
||||||
|
--count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sorted_nodes.reverse();
|
||||||
|
return sorted_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
NB_MODULE(utils, m) {
|
NB_MODULE(utils, m) {
|
||||||
@ -304,6 +371,13 @@ NB_MODULE(utils, m) {
|
|||||||
m.attr("safe_zip") = nb::steal<nb::object>(
|
m.attr("safe_zip") = nb::steal<nb::object>(
|
||||||
PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));
|
PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));
|
||||||
|
|
||||||
|
m.def("topological_sort", &TopologicalSort, nb::arg("parents_attr"),
|
||||||
|
nb::arg("end_nodes"),
|
||||||
|
"Computes a topological sort of a graph of objects. parents_attr is "
|
||||||
|
"the name of the attribute on each object that contains the list of "
|
||||||
|
"parent objects. end_nodes is an iterable of objects from which we "
|
||||||
|
"should start a backwards search.");
|
||||||
|
|
||||||
// Python has no reader-writer lock in its standard library, so we expose
|
// Python has no reader-writer lock in its standard library, so we expose
|
||||||
// bindings around absl::Mutex.
|
// bindings around absl::Mutex.
|
||||||
nb::class_<absl::Mutex>(m, "Mutex")
|
nb::class_<absl::Mutex>(m, "Mutex")
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for AOT compilation."""
|
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import unittest
|
import unittest
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for jax.api_util."""
|
|
||||||
|
|
||||||
import itertools as it
|
import itertools as it
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for Array."""
|
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import math
|
import math
|
||||||
|
@ -1356,6 +1356,32 @@ class VmappableTest(jtu.JaxTestCase):
|
|||||||
self.assertEqual(ans.names, expected.names)
|
self.assertEqual(ans.names, expected.names)
|
||||||
self.assertAllClose(ans.data, expected.data)
|
self.assertAllClose(ans.data, expected.data)
|
||||||
|
|
||||||
|
def test_types_with_same_spec(self):
|
||||||
|
# We register NamedArray.
|
||||||
|
batching.register_vmappable(NamedArray, NamedMapSpec, int,
|
||||||
|
named_to_elt, named_from_elt, None)
|
||||||
|
|
||||||
|
# We then register another type that uses NamedMapSpec as the spec_type too,
|
||||||
|
# and immediately unregister it.
|
||||||
|
class Foo:
|
||||||
|
pass
|
||||||
|
batching.register_vmappable(Foo, NamedMapSpec, int,
|
||||||
|
named_to_elt, named_from_elt, None)
|
||||||
|
batching.unregister_vmappable(Foo)
|
||||||
|
|
||||||
|
# We should still be able to use vmap on NamedArray.
|
||||||
|
def f(x):
|
||||||
|
return named_mul(x, x)
|
||||||
|
|
||||||
|
x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
|
||||||
|
ans = jax.jit(f)(x)
|
||||||
|
expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2)
|
||||||
|
|
||||||
|
self.assertEqual(ans.names, expected.names)
|
||||||
|
self.assertAllClose(ans.data, expected.data)
|
||||||
|
|
||||||
|
# And unregister NamedArray without exceptions.
|
||||||
|
batching.unregister_vmappable(NamedArray)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
@ -37,18 +37,41 @@ def call_kernel(
|
|||||||
m, n = grid
|
m, n = grid
|
||||||
return jnp.concatenate([
|
return jnp.concatenate([
|
||||||
jnp.concatenate([
|
jnp.concatenate([
|
||||||
kernel(i, j, *args) for j in range(n)], axis=1)
|
kernel((i, j), *args) for j in range(n)], axis=1)
|
||||||
for i in range(m)], axis=0)
|
for i in range(m)], axis=0)
|
||||||
|
|
||||||
|
|
||||||
def uniform_kernel(i: int, j: int, total_size, block_size, tile_size):
|
def call_kernel_3d(
|
||||||
"""Uniform random sampling kernel function."""
|
kernel,
|
||||||
global_key = jax.random.key(0)
|
grid: tuple[int, int],
|
||||||
keys = blocked_sampler.blocked_fold_in(global_key,
|
*args
|
||||||
|
):
|
||||||
|
"""Calls a kernel over a 3D grid and concatenates results to a single array."""
|
||||||
|
depth, rows, cols = grid
|
||||||
|
return jnp.concatenate([
|
||||||
|
jnp.concatenate([
|
||||||
|
jnp.concatenate([
|
||||||
|
jnp.array(kernel((i, j, k), *args))
|
||||||
|
for k in range(cols)], axis=2)
|
||||||
|
for j in range(rows)], axis=1)
|
||||||
|
for i in range(depth)], axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
def blocked_fold_in(block_index, key, total_size, block_size, tile_size):
|
||||||
|
"""Folds in block_index into global_key."""
|
||||||
|
return blocked_sampler.blocked_fold_in(key,
|
||||||
total_size=total_size,
|
total_size=total_size,
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
tile_size=tile_size,
|
tile_size=tile_size,
|
||||||
block_index=(i, j))
|
block_index=block_index)
|
||||||
|
|
||||||
|
|
||||||
|
def uniform_kernel(block_index, key, total_size, block_size, tile_size):
|
||||||
|
"""Uniform random sampling kernel function."""
|
||||||
|
keys = blocked_fold_in(block_index, key,
|
||||||
|
total_size=total_size,
|
||||||
|
block_size=block_size,
|
||||||
|
tile_size=tile_size)
|
||||||
return blocked_sampler.sample_block(jax.random.uniform,
|
return blocked_sampler.sample_block(jax.random.uniform,
|
||||||
keys,
|
keys,
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
@ -74,17 +97,46 @@ class BlockedSamplerTest(jtu.JaxTestCase):
|
|||||||
)
|
)
|
||||||
def test_block_shape_invariance(self, total_size, block_size_a,
|
def test_block_shape_invariance(self, total_size, block_size_a,
|
||||||
block_size_b, tile_size, transpose_grid):
|
block_size_b, tile_size, transpose_grid):
|
||||||
|
global_key = jax.random.key(0)
|
||||||
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
|
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
|
||||||
result_a = call_kernel(
|
result_a = call_kernel(
|
||||||
uniform_kernel, grid_a, transpose_grid,
|
uniform_kernel, grid_a, transpose_grid, global_key,
|
||||||
total_size, block_size_a, tile_size)
|
total_size, block_size_a, tile_size)
|
||||||
|
|
||||||
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
|
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
|
||||||
result_b = call_kernel(
|
result_b = call_kernel(
|
||||||
uniform_kernel, grid_b, transpose_grid,
|
uniform_kernel, grid_b, transpose_grid, global_key,
|
||||||
total_size, block_size_b, tile_size)
|
total_size, block_size_b, tile_size)
|
||||||
np.testing.assert_array_equal(result_a, result_b)
|
np.testing.assert_array_equal(result_a, result_b)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockedFoldInTest(jtu.JaxTestCase):
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
# Check that sampling a tensor of total size > jnp.iinfo(jnp.uint32).max works
|
||||||
|
# as expected. Specifically, blocked key folding does not depend on the total
|
||||||
|
# size of the tensor, but only the total number of tiles.
|
||||||
|
# Using a 3D grid (with very large inner dimensions) triggers an overflow in a
|
||||||
|
# previous implementation of blocked_fold_in.
|
||||||
|
dict(testcase_name='4096x512_vs_1024x2048',
|
||||||
|
total_size=(2, 64 * 1024, 64 * 1024), block_size_a=(1, 4096, 512),
|
||||||
|
block_size_b=(1, 1024, 2048), tile_size=(1, 1024, 512)),
|
||||||
|
)
|
||||||
|
def test_blocked_fold_in_shape_invariance(self, total_size, block_size_a,
|
||||||
|
block_size_b, tile_size):
|
||||||
|
global_key = jax.random.key(0)
|
||||||
|
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
|
||||||
|
result_a = call_kernel_3d(
|
||||||
|
blocked_fold_in, grid_a, global_key, total_size,
|
||||||
|
block_size_a, tile_size)
|
||||||
|
|
||||||
|
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
|
||||||
|
result_b = call_kernel_3d(
|
||||||
|
blocked_fold_in, grid_b, global_key, total_size,
|
||||||
|
block_size_b, tile_size)
|
||||||
|
np.testing.assert_array_equal(jax.random.key_data(result_a),
|
||||||
|
jax.random.key_data(result_b))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
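The sizes in `test_blocked_fold_in_shape_invariance` above are chosen so that the element count overflows a uint32 while the tile count stays small, which is exactly the property the fix relies on. The arithmetic, spelled out:

```python
total_size = (2, 64 * 1024, 64 * 1024)
tile_size = (1, 1024, 512)
n_elems = 2 * (64 * 1024) * (64 * 1024)
n_tiles = (2 // 1) * (64 * 1024 // 1024) * (64 * 1024 // 512)
assert n_elems == 8_589_934_592 > 2**32 - 1  # too big for uint32
assert n_tiles == 16_384                     # trivially enumerable
```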
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for release_backend_clients."""

 from absl.testing import absltest
 import jax
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-"""Tests for --debug_nans."""

 from absl.testing import absltest

 import jax
@@ -20,12 +20,14 @@ from jax._src import config
 from jax._src import error_check
 from jax._src import test_util as jtu
 import jax.numpy as jnp
+from jax.sharding import NamedSharding, PartitionSpec as P


 JaxValueError = error_check.JaxValueError


 config.parse_flags_with_absl()
+jtu.request_cpu_devices(4)


 @jtu.with_config(jax_check_tracer_leaks=True)
@@ -190,6 +192,23 @@ class ErrorCheckTests(jtu.JaxTestCase):
     ):
       jax.jit(error_check.raise_if_error)()

+  @parameterized.product(jit=[True, False])
+  @jtu.with_user_mesh((2, 2), ("x", "y"))
+  def test_error_check_explicit_mode(self, mesh, jit):
+    def f(x):
+      error_check.set_error_if(x <= 0, "x must be greater than 0")
+      return x + 1
+
+    if jit:
+      f = jax.jit(f)
+
+    sharding = NamedSharding(mesh, P("x", "y"))
+    x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding)
+    with error_check.error_checking_context():
+      f(x)
+      with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
+        error_check.raise_if_error()
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for garbage allocation guard."""

 import gc
 import weakref
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-"""Tests for jax.numpy.ufunc and its methods."""

 import itertools
 from functools import partial

@@ -3618,6 +3618,15 @@ class LaxTest(jtu.JaxTestCase):
     x = lax.optimization_barrier((2, 3))
     self.assertEqual((2, 3), x)

+  def test_optimization_barrier_autodiff(self):
+    def f(x):
+      y = 1. * x
+      x, y = lax.optimization_barrier((x, y))
+      z = 2. * x
+      return y + z
+    g = jax.grad(f)(5.)  # doesn't crash
+    self.assertAllClose(g, 3., check_dtypes=False)
+
 class LazyConstantTest(jtu.JaxTestCase):
   def _Check(self, make_const, expected):
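Why `test_optimization_barrier_autodiff` expects a gradient of 3: with the barrier treated as the identity, the function reduces to f(x) = 1·x + 2·x. A barrier-free sanity check:

```python
import jax

f = lambda x: 1. * x + 2. * x  # what the test computes, minus the barrier
assert jax.grad(f)(5.) == 3.
```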
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-"""Tests for the LAPAX linear algebra module."""

 from functools import partial
 import itertools
 from typing import Iterator
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for mesh utils."""

 import collections
 from collections.abc import Sequence
@@ -74,6 +74,37 @@ class LayoutInferenceTest(parameterized.TestCase):
     self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
     self.assertSequenceEqual(add.attributes["out_layouts"], [layout])

+  def test_infer_strided_layout_from_shape_cast(self):
+    shape = (16, 8)
+    elt_type = ir.BF16Type.get()
+    src_type = ir.VectorType.get(shape, elt_type)
+    dst_type = ir.VectorType.get([*reversed(shape)], elt_type)
+    op = None
+
+    def body(x):
+      nonlocal op
+      op = vector.ShapeCastOp(dst_type, x)
+
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(src_type)(body)
+
+    mgpu.infer_layout(self.module)
+
+    in_layout = layouts.to_layout_attr(
+        mgpu.WGStridedFragLayout.from_shaped_type(src_type)
+    )
+    out_layout = layouts.to_layout_attr(
+        mgpu.WGStridedFragLayout.from_shaped_type(dst_type)
+    )
+
+    self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout])
+    self.assertSequenceEqual(op.attributes["out_layouts"], [out_layout])
+
+    # Ensure that we can recover the original layout.
+    del op.attributes["in_layouts"]
+    mgpu.infer_layout(self.module)
+    self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout])
+
   def test_infer_splat_layout_for_splat_constants(self):
     shape = (16, 8)
     elt_type = ir.BF16Type.get()
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for Mosaic GPU DSL functions and utilities."""

 from collections.abc import Sequence
+import contextlib
 import dataclasses
 import enum
 import itertools
@@ -84,6 +84,20 @@ def mlir_sum(elems):
   return total


+@contextlib.contextmanager
+def get_sass():
+  prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
+  os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
+  try:
+    with jtu.capture_stdout() as output:
+      yield output
+  finally:
+    if prev_dump is not None:
+      os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
+    else:
+      del os.environ["MOSAIC_GPU_DUMP_SASS"]
+
+
 def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
   index = ir.IndexType.get()
   thread_id = gpu.thread_id(gpu.Dimension.x)
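The save/set/restore shape of `get_sass` is a generic pattern worth isolating. A standalone toy version (names here are illustrative, not part of the test suite), showing the same careful handling of a variable that was unset beforehand:

```python
import contextlib
import os

@contextlib.contextmanager
def env_var(name: str, value: str):
  prev = os.environ.get(name)
  os.environ[name] = value
  try:
    yield
  finally:
    if prev is not None:
      os.environ[name] = prev
    else:
      del os.environ[name]

# Assumes MY_FLAG is not already set in the environment.
with env_var("MY_FLAG", "1"):
  assert os.environ["MY_FLAG"] == "1"
assert "MY_FLAG" not in os.environ
```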
@@ -519,14 +533,38 @@ class WGMMALayoutTest(TestCase):
     )()
     np.testing.assert_array_equal(iota, expected)

-  @parameterized.named_parameters(
-      ("bf16_i8", jnp.bfloat16, jnp.int8),
-      ("i8_bf16", jnp.int8, jnp.bfloat16),
-      ("i8_i8", jnp.int8, jnp.int8),
-      ("i4_i4", jnp.int4, jnp.int4),
-      ("i4_bf16", jnp.int4, jnp.bfloat16),
+  @parameterized.parameters(jnp.int8, jnp.int16, jnp.int32)
+  def test_sub_byte_conversion(self, jax_dtype_to):
+    jax_dtype_from = jnp.int4
+    def kernel(ctx, inp, out, smem):
+      del ctx  # Unused.
+      smem_inp, smem_out = smem
+      copy(inp, smem_inp, swizzle=16)
+      t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16)
+      t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True)
+      t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
+      copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
+
+    x = self.prng.integers(
+        low=-8, high=7, size=(1, 1, 64, 64), dtype=np.int32
+    ).astype(jax_dtype_from)
+    y = x.astype(jax_dtype_to)
+    f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y))
+    np.testing.assert_array_equal(f(x), y)
+
+  @parameterized.product(
+      jax_dtype_from_to=(
+          (jnp.int8, jnp.bfloat16),
+          (jnp.int4, jnp.bfloat16),
+      ),
+      layout=(
+          fa.WGMMA_LAYOUT,
+          fa.WGMMA_LAYOUT_UPCAST_2X,
+          fa.WGMMA_LAYOUT_UPCAST_4X,
+      ),
   )
-  def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
+  def test_optimized_conversion(self, jax_dtype_from_to, layout):
+    jax_dtype_from, jax_dtype_to = jax_dtype_from_to
     mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
     mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
     m = 128
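The output swizzle in `test_sub_byte_conversion` scales with the destination element width, which (on our reading) keeps the swizzled tile's byte footprint constant across the parametrized dtypes. Checking the `32 * jnp.dtype(dt).itemsize` expression directly:

```python
import jax.numpy as jnp

for dt, swizzle_bytes in [(jnp.int8, 32), (jnp.int16, 64), (jnp.int32, 128)]:
  assert 32 * jnp.dtype(dt).itemsize == swizzle_bytes
```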
@@ -539,7 +577,7 @@ class WGMMALayoutTest(TestCase):
         smem_from,
         swizzle=128,
         is_signed=utils.is_signed(jax_dtype_from),
-        layout=fa._tiled_wgmma_layout((m, n))
+        layout=layout,
     )
     t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
     t.store_tiled(smem_to, swizzle=128)
@@ -2175,19 +2213,11 @@ class LayoutTest(TestCase):
         .transpose(0, 2, 1, 3)
     )

-    prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
-    os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
-    try:
-      with jtu.capture_stdout() as get_sass:
+    with get_sass() as sass:
       iota = mgpu.as_gpu_kernel(
           kernel, (1, 1, 1), (128, 1, 1), expected, expected,
           [expected, expected, mgpu.TMABarrier()],
       )(expected)
-    finally:
-      if prev_dump is not None:
-        os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
-      else:
-        del os.environ["MOSAIC_GPU_DUMP_SASS"]
     np.testing.assert_array_equal(iota, expected)

     # Verify that we don't use too many registers for the transfers.
@@ -2200,7 +2230,7 @@ class LayoutTest(TestCase):
       expected_regs //= 2
     for instr in ("STS", "LDS"):
       with self.subTest(instr + " count"):
-        addrs = re.findall(instr + r".* \[(.*)\]", get_sass())
+        addrs = re.findall(instr + r".* \[(.*)\]", sass())
         def get_reg(addr):
           if (pos := addr.find("+")) != -1:
             return addr[:pos]
@@ -2214,13 +2244,13 @@ class LayoutTest(TestCase):
     col_tiling = swizzle // bytewidth(utils.dtype_to_ir_type(dtype))
     m, n = 128, col_tiling * 2
     tiling = (64, col_tiling)
-    tiled_layout = fa._tiled_wgmma_layout_for_upcast((m, n))
+    layout = fa.WGMMA_LAYOUT_UPCAST_2X
     def kernel(ctx, in_, out, smems):
       smem_in, smem_out, barrier = smems
       ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
       barrier.wait()
       t = mgpu.FragmentedArray.load_tiled(
-          smem_in, swizzle=swizzle, is_signed=True, layout=tiled_layout
+          smem_in, swizzle=swizzle, is_signed=True, layout=layout
       )
       t.store_tiled(smem_out, swizzle=swizzle)
       mgpu.commit_shared()
@@ -2275,6 +2305,61 @@ class LayoutTest(TestCase):
     )(x)
     np.testing.assert_array_equal(y, y_ref)

+  @parameterized.parameters(
+      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int8, 1),
+      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int16, 1),
+      (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X, jnp.int4, jnp.int4, 1),
+      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5),
+      (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2),
+  )
+  def test_upcast_to_wgmma(
+      self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg
+  ):
+    in_dtype = jnp.dtype(in_dtype)
+    out_dtype = jnp.dtype(jnp.int16)
+    out_dtype_mlir = utils.dtype_to_ir_type(out_dtype)
+    swizzle = 128
+    in_col_tiling = 8 * swizzle // jnp.iinfo(in_dtype).bits
+    in_tiling = (8, in_col_tiling)
+    out_col_tiling = swizzle // out_dtype.itemsize
+    out_tiling = (8, out_col_tiling)
+    m, n = 128, in_col_tiling * 2
+    regs_per_thread = None
+    def kernel(ctx, in_, out, smems):
+      nonlocal regs_per_thread
+      smem_in, smem_out, barrier = smems
+      ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
+      barrier.wait()
+      t = mgpu.FragmentedArray.load_tiled(
+          smem_in, swizzle=swizzle, is_signed=True, layout=start_layout
+      )
+      regs_per_thread = t.registers.size
+      t = t.astype(utils.dtype_to_ir_type(cast_dtype), is_signed=True)
+      t = t.to_layout(end_layout)
+      t = t.astype(out_dtype_mlir, is_signed=True)
+      t.store_tiled(smem_out, swizzle=swizzle)
+      mgpu.commit_shared()
+      ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
+      ctx.await_async_copy(0)
+    def tile(x, tiling):
+      return x.reshape(
+          x.shape[0] // tiling[0], tiling[0], x.shape[1] // tiling[1], tiling[1]
+      ).transpose(0, 2, 1, 3)
+    in_iinfo = jnp.iinfo(in_dtype)
+    x = jax.random.randint(
+        jax.random.key(42), (m, n), in_iinfo.min, in_iinfo.max, dtype=jnp.int32
+    ).astype(in_dtype)
+    xt = tile(x, in_tiling)
+    y = x.astype(out_dtype)
+    yt = tile(y, out_tiling)
+    f = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()],
+    )
+    with get_sass() as sass:
+      yt_kernel = f(xt)
+    np.testing.assert_array_equal(yt_kernel, yt)
+    self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)
+
+
 @dataclasses.dataclass(frozen=True)
 class Tile:
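The `tile` helper in the test above is the standard reshape/transpose blocking: an (m, n) array becomes (m/tm, n/tn, tm, tn), so each leading index pair addresses one contiguous tile. A quick NumPy check of what lands where:

```python
import numpy as np

x = np.arange(16).reshape(4, 4)
t = x.reshape(2, 2, 2, 2).transpose(0, 2, 1, 3)  # tiling = (2, 2)
assert t.shape == (2, 2, 2, 2)
assert (t[1, 0] == x[2:4, 0:2]).all()  # tile (1, 0) is the lower-left block
```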
@ -25,8 +25,11 @@ from jax._src.interpreters import mlir as mlir_interpreter
|
|||||||
from jax._src.lib.mlir import ir
|
from jax._src.lib.mlir import ir
|
||||||
from jax._src.lib.mlir.dialects import arith
|
from jax._src.lib.mlir.dialects import arith
|
||||||
from jax._src.lib.mlir.dialects import func
|
from jax._src.lib.mlir.dialects import func
|
||||||
|
from jax._src.lib.mlir.dialects import vector
|
||||||
import jax.experimental.mosaic.gpu as mgpu
|
import jax.experimental.mosaic.gpu as mgpu
|
||||||
|
from jax.experimental.mosaic.gpu import fragmented_array as fa
|
||||||
from jax.experimental.mosaic.gpu import inference_utils
|
from jax.experimental.mosaic.gpu import inference_utils
|
||||||
|
from jax.experimental.mosaic.gpu import layouts as layouts_lib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@ -162,6 +165,259 @@ class TransformInferenceTest(parameterized.TestCase):
    )
    self.assertEmpty(inference_utils.out_transforms(async_store_op))

  def test_infer_transforms_for_vector_load_op_derives_from_destination(self):
    vector_load_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref):
      nonlocal vector_load_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_load_op = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      func.FuncOp.from_py_func(smem_ty)(body)

    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )

    mgpu.infer_transforms(self.module)

    expected_transforms = ir.ArrayAttr.get([
        mgpu.dialect.TileTransformAttr.get((8, 64)),
        mgpu.dialect.SwizzleTransformAttr.get(128),
    ])

    self.assertSequenceEqual(
        inference_utils.in_transforms(vector_load_op), [expected_transforms]
    )
    self.assertEmpty(inference_utils.out_transforms(vector_load_op))

  def test_infer_transforms_for_vector_load_op_derives_from_source(self):
    vector_load_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref):
      nonlocal vector_load_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_load_op = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      f = func.FuncOp.from_py_func(smem_ty)(body).func_op

    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
    )
    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])

    mgpu.infer_transforms(self.module)

    self.assertSequenceEqual(
        inference_utils.in_transforms(vector_load_op), [transforms]
    )
    self.assertEmpty(inference_utils.out_transforms(vector_load_op))

  def test_infer_transforms_for_vector_load_op_raises_on_mismatches(self):
    vector_load_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref):
      nonlocal vector_load_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_load_op = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      f = func.FuncOp.from_py_func(smem_ty)(body).func_op

    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )
    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])

    with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
      mgpu.infer_transforms(self.module)

  def test_infer_transforms_for_vector_store_op_derives_from_destination(self):
    vector_store_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref, value_to_store):
      nonlocal vector_store_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_store_op = vector.StoreOp(
          value_to_store, smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      value_ty = ir.VectorType.get(shape, elt_ty)
      func.FuncOp.from_py_func(smem_ty, value_ty)(body)

    vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )

    mgpu.infer_transforms(self.module)

    expected_transforms = ir.ArrayAttr.get([
        mgpu.dialect.TileTransformAttr.get((8, 64)),
        mgpu.dialect.SwizzleTransformAttr.get(128),
    ])

    self.assertSequenceEqual(
        inference_utils.in_transforms(vector_store_op), [expected_transforms]
    )
    self.assertEmpty(inference_utils.out_transforms(vector_store_op))

  def test_infer_transforms_for_vector_store_op_derives_from_source(self):
    vector_store_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref, value_to_store):
      nonlocal vector_store_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_store_op = vector.StoreOp(
          value_to_store, smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      value_ty = ir.VectorType.get(shape, elt_ty)
      f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op

    vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
    )
    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])

    mgpu.infer_transforms(self.module)

    self.assertSequenceEqual(
        inference_utils.in_transforms(vector_store_op), [transforms]
    )
    self.assertEmpty(inference_utils.out_transforms(vector_store_op))

  def test_infer_transforms_for_vector_store_op_raises_on_mismatches(self):
    vector_store_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref, value_to_store):
      nonlocal vector_store_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_store_op = vector.StoreOp(
          value_to_store, smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      value_ty = ir.VectorType.get(shape, elt_ty)
      f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op

    vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )
    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])

    with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
      mgpu.infer_transforms(self.module)

  def test_infer_transforms_for_slice_smem_op_derives_from_user(self):
    slice_smem_op = vector_load_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()
    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")

    def body(offset):
      nonlocal slice_smem_op, vector_load_op
      slice_smem_op = mgpu.dialect.SliceSMEMOp(
          ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset
      )
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      load_offsets = [zero] * len(shape)
      vector_load_op = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets
      )

    with ir.InsertionPoint(self.module.body):
      func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body)

    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )

    mgpu.infer_transforms(self.module)

    expected_transforms = ir.ArrayAttr.get([
        mgpu.dialect.TileTransformAttr.get((8, 64)),
        mgpu.dialect.SwizzleTransformAttr.get(128),
    ])

    self.assertEmpty(inference_utils.in_transforms(slice_smem_op))
    self.assertSequenceEqual(
        inference_utils.out_transforms(slice_smem_op), [expected_transforms]
    )

  def test_infer_transforms_for_slice_smem_op_raises_on_mismatches(self):
    slice_smem_op = vector_load_op1 = vector_load_op2 = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()
    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")

    def body(offset):
      nonlocal slice_smem_op, vector_load_op1, vector_load_op2
      slice_smem_op = mgpu.dialect.SliceSMEMOp(
          ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset
      )
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      load_offsets = [zero] * len(shape)
      vector_load_op1 = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets
      )
      vector_load_op2 = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets
      )

    with ir.InsertionPoint(self.module.body):
      func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body)

    vector_load_op1.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )
    vector_load_op2.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
    )
    vector_load_op2.attributes["in_transforms"] = ir.ArrayAttr.get(
        [ir.ArrayAttr.get([mgpu.dialect.TransposeTransformAttr.get((1, 0))])]
    )

    with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
      mgpu.infer_transforms(self.module)


if __name__ == "__main__":
  parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
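All of the new transform-inference tests above share the same build-annotate-infer shape: wrap the op under test in a func.FuncOp, attach in_layouts/out_layouts attributes, run mgpu.infer_transforms, and check the result via inference_utils. As a purely illustrative sketch (this helper is an assumption, not part of the diff), the repeated SMEM function setup could be factored out like so:

# Hypothetical helper, not part of the diff: factors out the repeated
# "wrap `body` in a function taking one SMEM memref" setup used by the
# vector load/store tests above. `ir` and `func` are the MLIR bindings
# the test file already imports.
def build_smem_func(module, shape, elt_ty, body):
  with ir.InsertionPoint(module.body):
    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
    smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
    return func.FuncOp.from_py_func(smem_ty)(body).func_op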
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for Mosaic GPU CUPTI-based profiler."""

from absl.testing import absltest, parameterized
import jax
@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for nn module."""
-
import collections
from functools import partial
import itertools
@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for the optimizers module."""
-
import functools

from absl.testing import absltest
@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for pull block spec."""
-
from absl.testing import absltest
from absl.testing import parameterized
import jax
@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for Pallas indexing logic and abstractions."""
-
from __future__ import annotations
import sys
import unittest
@ -185,7 +185,7 @@ class PallasCallTest(PallasTest):
    np.testing.assert_array_equal(kernel(x, y), x + y[0])

  @parameterized.product(
-      shape=[(128,)], thread_semantics=[*plgpu.ThreadSemantics]
+      shape=[(128,), (128, 128)], thread_semantics=[*plgpu.ThreadSemantics]
  )
  def test_reduce_sum(self, shape, thread_semantics):
    @functools.partial(
@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for common JAX operations within pallas_call."""
-
from collections.abc import Sequence
import functools
import itertools
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for Pallas error handling."""
import functools
import traceback
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Tests for TPU specific operations within pallas_call."""

import functools
import math
@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for distributed pallas TPU operations."""
-
import functools
import os
import tempfile
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for random ops in Pallas + Mosaic."""
|
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for Pallas mesh API."""
import functools
from absl.testing import absltest
import jax
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for splash_attention."""
from __future__ import annotations

from collections.abc import Callable
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for splash_attention_masks."""
from __future__ import annotations

from absl.testing import absltest
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Tests for interoperability between JAX and pickling libraries."""

import pickle
import unittest
@ -6138,6 +6138,19 @@ class ShardingInTypesTest(jtu.JaxTestCase):
    self.assertDictEqual(out.sharding.mesh._axis_types_dict,
                         {AxisType.Auto: ('x',)})

  @jtu.with_user_mesh((2,), 'x')
  def test_device_put_use_mesh(self, mesh):
    out = jax.device_put(np.arange(8), P('x'))
    self.assertArraysEqual(out, np.arange(8))
    self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))

  def test_device_put_no_use_mesh_error(self):
    with self.assertRaisesRegex(
        ValueError,
        'Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is'
        ' passed to device_put'):
      jax.device_put(np.arange(8), P('x'))

  @jtu.with_user_mesh((2,), 'x')
  def test_inputs_different_context(self, mesh):
    np_inp = np.arange(16).reshape(8, 2)
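The two device_put tests above pin down the new contract: a bare PartitionSpec is only accepted while a mesh is active. A minimal usage sketch under that assumption (the mesh construction here is illustrative):

import jax
import numpy as np
from jax.sharding import PartitionSpec as P

# Illustrative 1-D mesh over all available devices.
mesh = jax.make_mesh((jax.device_count(),), ('x',))
with jax.sharding.use_mesh(mesh):
  # With a mesh active, device_put accepts a bare PartitionSpec.
  out = jax.device_put(np.arange(8), P('x'))
# Outside `use_mesh`, the same call raises the ValueError tested above.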
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License

-"""Tests for the library of QDWH-based polar decomposition."""
import functools

from absl.testing import absltest
@ -365,6 +365,38 @@ class LaxRandomTest(jtu.JaxTestCase):
    pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
    self._CheckChiSquared(samples, pmf=pmf)

  @jtu.sample_product(
      logits_shape=[(7,), (8, 9), (10, 11, 12)],
      prefix_shape=[(2,), (3, 4), (5, 6)],
  )
  def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape):
    key = random.key(0)

    key, subkey = random.split(key)
    logits = random.normal(subkey, logits_shape)

    key, subkey = random.split(key)
    axis = random.randint(subkey, (), -len(logits_shape), len(logits_shape))

    dists_shape = tuple(np.delete(logits_shape, axis))
    n_categories = logits_shape[axis]
    shape = prefix_shape + dists_shape
    prefix_size = math.prod(prefix_shape)

    if n_categories < prefix_size:
      with self.assertRaisesRegex(ValueError, "Number of samples without replacement"):
        random.categorical(key, logits, axis=axis, shape=shape, replace=False)
    else:
      output = random.categorical(key, logits, axis=axis, shape=shape, replace=False)
      self.assertEqual(output.shape, shape)
      assert (0 <= output).all()
      assert (output < n_categories).all()
      flat = output.reshape((prefix_size, math.prod(dists_shape)))
      counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat)
      assert (counts <= 1).all()

  def testBernoulliShape(self):
    key = self.make_key(0)
    with jax.numpy_rank_promotion('allow'):
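The new test exercises `random.categorical(..., replace=False)`, i.e. drawing category indices without repetition. A small usage sketch (sizes are illustrative):

import jax
import jax.numpy as jnp

key = jax.random.key(0)
logits = jnp.zeros(10)  # ten equally likely categories
# Draw four distinct categories; replace=False forbids repeats.
samples = jax.random.categorical(key, logits, shape=(4,), replace=False)
assert len(set(samples.tolist())) == 4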
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Tests for the shape-polymorphic export."""

from __future__ import annotations

@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for stack."""
-
-
from absl.testing import absltest

import jax
@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for Stax library."""
-
from absl.testing import absltest

import numpy as np
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License

-"""Tests for the library of QDWH-based singular value decomposition."""
import functools

import jax
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Tests for transfer guards."""

import contextlib
import pickle
@ -201,5 +201,49 @@ class SafeZipTest(jtu.JaxTestCase):
      util.safe_zip((), range(3))


class Node:
  def __init__(self, parents):
    self.parents = parents


class TopologicalSortTest(jtu.JaxTestCase):

  def _check_topological_sort(self, nodes, order):
    self.assertEqual(sorted(nodes, key=id), sorted(order, key=id))
    visited = set()
    for node in order:
      self.assertTrue(all(id(parent) in visited for parent in node.parents))
      visited.add(id(node))

  def test_basic(self):
    a = Node([])
    b = Node([a])
    c = Node([a])
    d = Node([a, c])
    e = Node([b, c])
    out = util.toposort([a, d, e])
    self._check_topological_sort([a, b, c, d, e], out)

  def test_stick(self):
    a = Node([])
    b = Node([a])
    c = Node([b])
    d = Node([c])
    e = Node([d])
    out = util.toposort([e])
    self._check_topological_sort([a, b, c, d, e], out)

  def test_diamonds(self):
    a = Node([])
    b = Node([a])
    c = Node([a])
    d = Node([b, c])
    e = Node([d])
    f = Node([d])
    g = Node([e, f])
    out = util.toposort([g])
    self._check_topological_sort([a, b, c, d, e, f, g], out)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
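The property `_check_topological_sort` verifies — every node appears after all of its parents — is the defining guarantee of a topological sort over this parent-pointer representation. A minimal depth-first sketch over the same `Node` shape (an illustration, not necessarily the actual `jax.util.toposort` implementation):

def toposort(end_nodes):
  """Returns all nodes reachable from `end_nodes`, parents before children."""
  order = []
  visited = set()

  def visit(node):
    if id(node) in visited:
      return
    visited.add(id(node))
    for parent in node.parents:
      visit(parent)
    order.append(node)  # emitted only once every parent is already in `order`

  for node in end_nodes:
    visit(node)
  return order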
4
third_party/xla/workspace.bzl
vendored
@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.

-XLA_COMMIT = "4c4aa96f9ffec4bb963b50c50192aeab4da9dc4a"
-XLA_SHA256 = "c373e52b2f8b4175c69e99e636ad64b3bcf33fb44d1b7ad6ef8f4162c9052af8"
+XLA_COMMIT = "3bb765472122548cc227b8bd2990f00bd533f438"
+XLA_SHA256 = "72126aac7602153aee985ca20f73d11c39e3ba9cfb8027492951e787559d0497"

def repo():
    tf_http_archive(
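The comment in this hunk spells out the manual update flow: download the archive for the new commit and recompute its SHA-256. A hedged standard-library Python equivalent of that curl/sha256sum one-liner (commit hash taken from this diff):

import hashlib
import urllib.request

XLA_COMMIT = "3bb765472122548cc227b8bd2990f00bd533f438"
url = f"https://github.com/openxla/xla/archive/{XLA_COMMIT}.tar.gz"

# Stream the tarball and hash it, mirroring `curl -L ... | sha256sum`.
digest = hashlib.sha256()
with urllib.request.urlopen(url) as resp:
  for chunk in iter(lambda: resp.read(1 << 20), b""):
    digest.update(chunk)
print(digest.hexdigest())  # should match XLA_SHA256 above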