mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)

Merge pull request #294 from ROCm/ci-upstream-sync-151_1
CI: 03/18/25 upstream sync

This commit is contained in: commit c46b4fc02b
.github/workflows/pytest_cpu.yml (vendored, 5 changes)
@@ -118,6 +118,11 @@ jobs:
         run: |
           $JAXCI_PYTHON -m pip install uv~=0.5.30
           $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
+
+          # CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
+          if [[ $OS == "linux" && $ARCH == "aarch64" ]]; then
+            $JAXCI_PYTHON -m uv pip install numpy~=2.1.0
+          fi
       # Halt for testing
       - name: Wait For Connection
         uses: google-ml-infra/actions/ci_connection@main
.github/workflows/pytest_cuda.yml (vendored, 3 changes)
@@ -54,7 +54,8 @@ jobs:
     runs-on: ${{ inputs.runner }}
     # TODO: Update to the generic ML ecosystem test containers when they are ready.
     container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
-                   (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
+                   (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') ||
+                   (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }}
     name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"

     env:
.github/workflows/wheel_tests_continuous.yml (vendored, 24 changes)
@@ -110,18 +110,30 @@ jobs:
       fail-fast: false # don't cancel all jobs on failure
       matrix:
         # Python values need to match the matrix strategy in the artifact build jobs above
-        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
+        # See exclusions for what is fully tested
+        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"]
         python: ["3.10",]
-        cuda: ["12.3", "12.1"]
+        cuda: ["12.1","12.3","12.8"]
         enable-x64: [1, 0]
         exclude:
-          # Run only a single configuration on H100 to save resources
+          # L4 does not run on cuda 12.8 but tests other configs
+          - runner: "linux-x86-g2-48-l4-4gpu"
+            cuda: "12.8"
+          # H100 runs only a single config, CUDA 12.3 Enable x64 1
+          - runner: "linux-x86-a3-8g-h100-8gpu"
+            cuda: "12.8"
           - runner: "linux-x86-a3-8g-h100-8gpu"
             python: "3.10"
             cuda: "12.1"
           - runner: "linux-x86-a3-8g-h100-8gpu"
             python: "3.10"
-            enable-x64: 0
+            enable-x64: "0"
+          # B200 runs only a single config, CUDA 12.8 Enable x64 1
+          - runner: "linux-x86-a4-224-b200-1gpu"
+            enable-x64: "0"
+          - runner: "linux-x86-a4-224-b200-1gpu"
+            cuda: "12.1"
+          - runner: "linux-x86-a4-224-b200-1gpu"
+            cuda: "12.3"

       name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
       with:
         runner: ${{ matrix.runner }}
@@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   {func}`jax.lax.dynamic_update_slice` and related functions. The default is
   true, matching the current behavior. If set to false, JAX does not need to
   emit code clamping negative indices, which improves code size.
+* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
+  without replacement.

 ## jax 0.5.2 (Mar 4, 2025)
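A minimal usage sketch of the `replace` option added above (values are illustrative; `replace=False` requests distinct draws via the Gumbel top-k path):

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))

# Default: three draws with replacement (Gumbel max trick).
with_repl = jax.random.categorical(key, logits, shape=(3,))
# New in this release: three distinct categories (Gumbel top-k trick).
without_repl = jax.random.categorical(key, logits, shape=(3,), replace=False)
```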
@@ -18,4 +18,7 @@ setuptools
 matplotlib~=3.8.4; python_version=="3.10"
 matplotlib; python_version>="3.11"
 opt-einsum
-auditwheel
+
+# CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
+numpy~=2.1.0; platform_system == "Linux" and platform_machine == "aarch64"
+auditwheel
@@ -49,13 +49,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 7,
    "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "hVi6mApuVw3r",
-    "outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf"
+    "id": "hVi6mApuVw3r"
    },
    "outputs": [],
    "source": [
@@ -84,13 +80,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 8,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "mzDIDvj7Vw0k",
-    "outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434"
+    "outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a"
    },
    "outputs": [
     {
@@ -119,13 +115,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 9,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "IyPx_-IBVwxr",
-    "outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499"
+    "outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb"
    },
    "outputs": [
     {
@@ -141,7 +137,7 @@
       "Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)"
      ]
     },
-    "execution_count": 3,
+    "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -172,13 +168,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 10,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "NO2ulM_QW7a8",
-    "outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb"
+    "outputId": "d888371b-080e-4bff-be5d-ea56beda3aac"
    },
    "outputs": [
     {
@@ -208,13 +204,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 11,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "1-TzmA0AXCAf",
-    "outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71"
+    "outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2"
    },
    "outputs": [
     {
@@ -256,13 +252,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 12,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "Gy7ABds3XND3",
-    "outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b"
+    "outputId": "0d72dad2-381a-4e96-f771-40d705da1376"
    },
    "outputs": [
     {
@@ -297,13 +293,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 13,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "grCcotr-XQjY",
-    "outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a"
+    "outputId": "c2db656c-809f-49a6-c948-629d6420360c"
    },
    "outputs": [
     {
@@ -324,7 +320,7 @@
       "       [ 3,  4,  5,  6,  7,  8,  9, 10]], dtype=int32)"
      ]
     },
-    "execution_count": 7,
+    "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -460,13 +456,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 14,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "fpFEaMBcXsJG",
-    "outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660"
+    "outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef"
    },
    "outputs": [
     {
@@ -479,13 +475,6 @@
      "We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n",
      "Result type: ShapedArray(int32[4@X,4])\n"
     ]
-   },
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "Result type: ShapedArray(int32[4@X,4])\n"
-    ]
    }
   ],
   "source": [
@@ -550,13 +539,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 15,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
     "id": "geptWrdYX0OM",
-    "outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f"
+    "outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f"
    },
    "outputs": [
     {
@@ -588,7 +577,88 @@
   {
    "cell_type": "markdown",
    "metadata": {
-    "id": "AQQjzUeGX4P6"
+    "id": "LZWjgiMZ7uSS"
    },
    "source": [
     "You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over others. For example:"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "IVzPSkp77uCF",
+    "outputId": "db80a604-98ac-4343-8677-23729adf7ffc"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n",
+      "x.sharding: ShapedArray(float32[4@X,4@Y])\n",
+      "\n",
+      "mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n",
+      "y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n",
+      "\n",
+      "z.sharding: ShapedArray(float32[4@X,4@Y])\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "Array([[ 1.        ,  2.682942  ,  2.818595  ,  1.28224   ],\n",
+       "       [-0.513605  , -0.9178486 ,  0.44116902,  2.3139732 ],\n",
+       "       [ 2.9787164 ,  1.824237  , -0.08804226, -0.99998045],\n",
+       "       [-0.07314587,  1.840334  ,  2.9812148 ,  2.3005757 ]], dtype=float32)"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import functools\n",
+    "\n",
+    "@functools.partial(auto_axes, axes='X')\n",
+    "def g(y):\n",
+    "  print(f'mesh inside g: {get_abstract_mesh()}')\n",
+    "  print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n",
+    "  return y * 2\n",
+    "\n",
+    "@jax.jit\n",
+    "def f(arr1):\n",
+    "  print(f'mesh inside f: {get_abstract_mesh()}')\n",
+    "  x = jnp.sin(arr1)\n",
+    "  print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n",
+    "\n",
+    "  z = g(x, out_shardings=P(\"X\", \"Y\"))\n",
+    "\n",
+    "  print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n",
+    "  return z + 1\n",
+    "\n",
+    "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n",
+    "f(some_x)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "_3sfJjRq8w9f"
+   },
+   "source": [
+    "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])`, which indicates it's Explicit over the `Y` mesh axis while Auto over `X`."
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
    "id": "sJcWbfAh7UcO"
    },
    "source": [
     "## Concrete array shardings can mention `Auto` mesh axis\n",
@@ -606,7 +676,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": null,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
@@ -708,5 +778,5 @@
   }
  },
 "nbformat": 4,
- "nbformat_minor": 4
+ "nbformat_minor": 0
}
@@ -50,12 +50,8 @@ expect there to be bugs and unimplemented cases. Please let us know when you
 find something that doesn't work!

 ```{code-cell} ipython3
----
-colab:
-  base_uri: https://localhost:8080/
-id: hVi6mApuVw3r
-outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf
----
+:id: hVi6mApuVw3r

 import jax
 import numpy as np
 import jax.numpy as jnp
@@ -79,7 +75,7 @@ scalar) using `jax.typeof`:
 colab:
   base_uri: https://localhost:8080/
 id: mzDIDvj7Vw0k
-outputId: 417b8453-9c86-4e76-a886-4fa9fdb16434
+outputId: 09ef049b-461f-47db-bf58-dc10b42fe40a
 ---
 some_array = np.arange(8)
 print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
@@ -96,7 +92,7 @@ under a jit).
 colab:
   base_uri: https://localhost:8080/
 id: IyPx_-IBVwxr
-outputId: 7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499
+outputId: 0cd3122f-e579-45d7-868d-e42bb0eacddb
 ---
 @jax.jit
 def foo(x):
@@ -121,7 +117,7 @@ mesh afterwards then you can use the context manager `jax.sharding.use_mesh` instead:
 colab:
   base_uri: https://localhost:8080/
 id: NO2ulM_QW7a8
-outputId: ea313610-146c-41f4-95b4-c5a5b2b407cb
+outputId: d888371b-080e-4bff-be5d-ea56beda3aac
 ---
 mesh = jax.make_mesh((2, 4), ("X", "Y"),
                      axis_types=(AxisType.Explicit, AxisType.Explicit))
@@ -139,7 +135,7 @@ Now we can create some sharded arrays using `reshard`:
 colab:
   base_uri: https://localhost:8080/
 id: 1-TzmA0AXCAf
-outputId: 15b33b6d-3915-4725-da6d-4f31fb78fe71
+outputId: 1c7cc3ac-4b0e-42b7-facc-c706af10d7d2
 ---
 replicated_array = np.arange(8).reshape(4, 2)
 sharded_array = reshard(replicated_array, P("X", None))
@@ -163,7 +159,7 @@ These shardings associated with JAX-level types propagate through operations. For example:
 colab:
   base_uri: https://localhost:8080/
 id: Gy7ABds3XND3
-outputId: 4ced73ed-5872-45f3-a4a6-2138f942e01b
+outputId: 0d72dad2-381a-4e96-f771-40d705da1376
 ---
 arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
 arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))
@@ -184,7 +180,7 @@ We can do the same type querying under a jit:
 colab:
   base_uri: https://localhost:8080/
 id: grCcotr-XQjY
-outputId: 9a9f381d-5111-4824-9bc0-cb2472cb8e6a
+outputId: c2db656c-809f-49a6-c948-629d6420360c
 ---
 @jax.jit
 def add_arrays(x, y):
@@ -294,7 +290,7 @@ the first axis only, like `f32[4@X, 4]`. You can do this as follows:
 colab:
   base_uri: https://localhost:8080/
 id: fpFEaMBcXsJG
-outputId: d28a69eb-260f-4fc5-8f19-2cc64cc70660
+outputId: 5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef
 ---
 some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
 some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))
@@ -355,7 +351,7 @@ The current mesh tells us which sharding mode we're in. We can query it with
 colab:
   base_uri: https://localhost:8080/
 id: geptWrdYX0OM
-outputId: c0e62eb1-9f79-4d1c-e708-526165ca680f
+outputId: b8c3813f-60bb-4ccf-9da7-73462c57963f
 ---
 print(f"Current mesh is: {get_abstract_mesh()}")
 ```
@@ -369,7 +365,45 @@ sharding mode for each mesh axis. Shardings (on JAX-level types) can only
 mention _explicit_ mesh axes and collective operations like `psum` can only
 mention _manual_ mesh axes.

-+++ {"id": "AQQjzUeGX4P6"}
++++ {"id": "LZWjgiMZ7uSS"}

 You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over others. For example:

+```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+id: IVzPSkp77uCF
+outputId: db80a604-98ac-4343-8677-23729adf7ffc
+---
+import functools
+
+@functools.partial(auto_axes, axes='X')
+def g(y):
+  print(f'mesh inside g: {get_abstract_mesh()}')
+  print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n')
+  return y * 2
+
+@jax.jit
+def f(arr1):
+  print(f'mesh inside f: {get_abstract_mesh()}')
+  x = jnp.sin(arr1)
+  print(f'x.sharding: {jax.typeof(x)}', end='\n\n')
+
+  z = g(x, out_shardings=P("X", "Y"))
+
+  print(f'z.sharding: {jax.typeof(z)}', end="\n\n")
+  return z + 1
+
+some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y"))
+f(some_x)
+```
+
++++ {"id": "_3sfJjRq8w9f"}
+
+As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])`, which indicates it's Explicit over the `Y` mesh axis while Auto over `X`.
+
++++ {"id": "sJcWbfAh7UcO"}

 ## Concrete array shardings can mention `Auto` mesh axis
@@ -299,7 +299,7 @@
     " ):\n",
     "  \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n",
     "  del idxs_k_ref\n",
-    "  blk_idx = pl.program_id(0)\n",
+    "  blk_idx = pl.program_id(1)\n",
     "  is_start = blk_idx == 0\n",
     "  changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n",
     "  @pl.when(is_start | changed_blocks)\n",
@@ -314,13 +314,13 @@
     "    o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n",
     "\n",
     "\n",
-    "def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
+    "def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
     "  del j, blk_idxs_i, blk_idxs_k\n",
     "  return (blk_idx, 0, 0)\n",
-    "def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
+    "def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
     "  del blk_idxs_i\n",
     "  return (blk_idxs_k[blk_idx], j)\n",
-    "def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
+    "def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
     "  del blk_idxs_k\n",
     "  return (blk_idxs_i[blk_idx], j)\n",
     "\n",
@@ -335,7 +335,7 @@
     "    num_scalar_prefetch=2,\n",
     "    # Note that while num_blocks is static here, Pallas does support\n",
     "    # dynamic grid sizes.\n",
-    "    grid=(num_blocks, N // blk_N),\n",
+    "    grid=(N // blk_N, num_blocks),\n",
     "    in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n",
     "              pl.BlockSpec((blk_K, blk_N), y_map),\n",
     "              # Placeholder for a zeros-array used by input_output_aliases.\n",
@@ -239,7 +239,7 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref,  # Scalar prefetch inputs.
                ):
   """A DSD (Dense = Sparse @ Dense) matmul kernel."""
   del idxs_k_ref
-  blk_idx = pl.program_id(0)
+  blk_idx = pl.program_id(1)
   is_start = blk_idx == 0
   changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
   @pl.when(is_start | changed_blocks)
@@ -254,13 +254,13 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref,  # Scalar prefetch inputs.
     o_ref[...] = accum_scratch[...].astype(o_ref.dtype)


-def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
+def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
   del j, blk_idxs_i, blk_idxs_k
   return (blk_idx, 0, 0)
-def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
+def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
   del blk_idxs_i
   return (blk_idxs_k[blk_idx], j)
-def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
+def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
   del blk_idxs_k
   return (blk_idxs_i[blk_idx], j)

@@ -275,7 +275,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
     num_scalar_prefetch=2,
     # Note that while num_blocks is static here, Pallas does support
     # dynamic grid sizes.
-    grid=(num_blocks, N // blk_N),
+    grid=(N // blk_N, num_blocks),
     in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
               pl.BlockSpec((blk_K, blk_N), y_map),
               # Placeholder for a zeros-array used by input_output_aliases.
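A note on the reordering in the hunks above: Pallas passes grid coordinates to each `index_map` in grid order, followed by the scalar-prefetch refs, so moving the sparse-block dimension from grid position 0 to position 1 requires swapping the first two parameters of every map and reading `pl.program_id(1)` in the kernel. A schematic restatement of that convention (hypothetical `G0`/`G1` sizes, not library code):

```python
# With grid=(G0, G1) and num_scalar_prefetch=2, Pallas calls each index map as
#   index_map(g0, g1, prefetch_ref_0, prefetch_ref_1)
# so after this change j (the N // blk_N coordinate) arrives first and the
# sparse block index blk_idx arrives second.
def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
    del blk_idxs_k                    # unused prefetch ref
    return (blk_idxs_i[blk_idx], j)   # output row block from prefetch, column from j
```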
@@ -81,7 +81,7 @@ int main(int argc, char** argv) {
   xla::XlaComputation xla_computation(test_module_proto);
   xla::CompileOptions compile_options;
   std::unique_ptr<xla::PjRtLoadedExecutable> executable =
-      client->Compile(xla_computation, compile_options).value();
+      client->CompileAndLoad(xla_computation, compile_options).value();

   // Prepare inputs.
   xla::Literal literal_x =
@@ -799,7 +799,7 @@ pytype_strict_library(
 )

 # This target only supports sm_90 GPUs.
-py_library(
+py_library_providing_imports_info(
     name = "mosaic_gpu",
     srcs = glob(["experimental/mosaic/gpu/*.py"]),
     visibility = [
@@ -824,6 +824,7 @@ py_library(
         "//jaxlib/mlir:pass_manager",
         "//jaxlib/mlir:scf_dialect",
         "//jaxlib/mlir:vector_dialect",
+        "//jaxlib/mosaic/python:gpu_dialect",
     ] + py_deps("absl/flags") + py_deps("numpy"),
 )
@@ -67,7 +67,9 @@ from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
+from jax._src.mesh import get_concrete_mesh
+from jax._src.sharding_impls import (
+    PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding)
 from jax._src.layout import Layout, AutoLayout
 from jax._src.traceback_util import api_boundary
 from jax._src import tree_util
@@ -2280,11 +2282,20 @@ def _check_sharding(aval, s):
         (s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False)
     s.shard_shape(aval.shape)  # should raise an Error if incompatible

+def pspec_to_sharding(val):
+  if isinstance(val, P):
+    mesh = get_concrete_mesh()
+    if mesh is None:
+      raise ValueError(
+          "Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is"
+          " passed to device_put")
+    return NamedSharding(mesh, val)
+  return val
+
 def device_put(
     x,
-    device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
-    *, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
+    device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
+    *, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
     donate: bool | Any = False, may_alias: bool | None | Any = None):
   """Transfers ``x`` to ``device``.

@@ -2333,6 +2344,9 @@ def device_put(
   src_flat = flatten_axes("device_put source", treedef, src)
   src_flat = list(map(_infer_src_sharding, src_flat, x_flat))

+  device_flat = map(pspec_to_sharding, device_flat)
+  src_flat = map(pspec_to_sharding, src_flat)
+
   if isinstance(donate, bool):
     donate_flat = [donate] * len(x_flat)
   else:
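A hedged sketch of what the `device_put` change above enables (the mesh construction is illustrative and assumes at least two devices):

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2,), ("x",))
x = jnp.arange(8)

# A bare PartitionSpec is now accepted; pspec_to_sharding resolves it to a
# NamedSharding against the mesh installed by use_mesh.
with jax.sharding.use_mesh(mesh):
    y = jax.device_put(x, P("x"))
```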
@@ -28,17 +28,17 @@ class SampleFn(Protocol):
     ...


-def _compute_scalar_index(iteration_index: Sequence[int],
-                          total_size: Shape,
-                          block_size: Shape,
-                          block_index: Sequence[int]) -> int:
-  ndims = len(iteration_index)
+def _compute_tile_index(block_index: Sequence[int],
+                        total_size_in_blocks: Shape,
+                        block_size_in_tiles: Shape,
+                        tile_index_in_block: Sequence[int]) -> int:
+  ndims = len(block_index)
   dim_size = 1
   total_idx = 0
   for i in range(ndims-1, -1, -1):
-    dim_idx = block_index[i] + iteration_index[i] * block_size[i]
+    dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
     total_idx += dim_idx * dim_size
-    dim_size *= total_size[i]
+    dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
   return total_idx


@@ -99,18 +99,23 @@ def blocked_fold_in(
     An N-dimensional nested list of keys required to sample the tiles
     corresponding to the block specified by `block_index`.
   """
-  size_in_blocks = tuple(
-      _shape // _element for _shape, _element in zip(block_size, tile_size))
+  block_size_in_tiles = tuple(
+      _shape // _element for _shape, _element in zip(block_size, tile_size)
+  )
+
+  total_size_in_blocks = tuple(
+      _shape // _element for _shape, _element in zip(total_size, block_size)
+  )

   def _keygen_loop(axis, prefix):
-    if axis == len(size_in_blocks):
+    if axis == len(block_size_in_tiles):
       subtile_key = jax.random.fold_in(
-          global_key, _compute_scalar_index(
-              block_index, total_size, size_in_blocks, prefix))
+          global_key, _compute_tile_index(
+              block_index, total_size_in_blocks, block_size_in_tiles, prefix))
       return subtile_key
     else:
       keys = []
-      for i in range(size_in_blocks[axis]):
+      for i in range(block_size_in_tiles[axis]):
         keys.append(_keygen_loop(axis+1, prefix+(i,)))
       return keys
   return _keygen_loop(0, tuple())
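A standalone sketch of the row-major flattening the renamed helper performs, with made-up sizes (this mirrors the logic above but is not the library code):

```python
def compute_tile_index(block_index, total_size_in_blocks, block_size_in_tiles,
                       tile_index_in_block):
    # Each dimension i holds total_size_in_blocks[i] * block_size_in_tiles[i]
    # tiles; flatten (block, tile-in-block) coordinates row-major.
    total_idx, dim_size = 0, 1
    for i in range(len(block_index) - 1, -1, -1):
        dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
        total_idx += dim_idx * dim_size
        dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
    return total_idx

# A 4x4 grid of blocks, each 2x2 tiles: block (1, 0), tile (0, 1) inside it.
assert compute_tile_index((1, 0), (4, 4), (2, 2), (0, 1)) == 17
```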
@@ -446,7 +446,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
   if len(devices) == 1:
     # If we only have one device in our computation, we can construct a
     # replicated HloSharding and call it right now.
-    _hlo_sharding_callback(sharding_impls.get_replicated_hlo_sharding())
+    _hlo_sharding_callback(sharding_impls.replicated_hlo_sharding)
     return []

   key = xc.encode_inspect_sharding_callback(_hlo_sharding_callback)
@@ -466,11 +466,14 @@ def _device_put_sharding_impl(x, aval, device, copy):
     if not s.is_fully_addressable:
       if ((isinstance(x, array.ArrayImpl) and not x._committed) or
           type(x) in array_types):
-        multihost_utils.assert_equal(
-            x, fail_message=(
-                f"{type(x)} passed to device_put is not the same on each"
-                " process. Make sure you are passing the same value of"
-                f" {type(x)} on each process."))
+        # TODO(emilyaf): Remove this condition when jit works when a sharding
+        # has no local devices.
+        if not config.enable_empty_arrays.value:
+          multihost_utils.assert_equal(
+              x, fail_message=(
+                  f"{type(x)} passed to device_put is not the same on each"
+                  " process. Make sure you are passing the same value of"
+                  f" {type(x)} on each process."))
         return _DeferredShardArg(x, s, aval, True, copy)
       # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
       raise ValueError(
@@ -14,13 +14,17 @@

 from __future__ import annotations

+from functools import partial
 import threading

 import jax
 from jax._src import core
 from jax._src import source_info_util
 from jax._src import traceback_util
+import jax._src.mesh as mesh_lib
+from jax.experimental.shard_map import shard_map
 import jax.numpy as jnp
+from jax.sharding import NamedSharding, PartitionSpec as P


 Traceback = source_info_util.Traceback
@@ -54,17 +58,61 @@ _error_storage = _ErrorStorage()


 def _initialize_error_code_ref() -> None:
-  """Initialize error_code_ref in the current thread."""
+  """Initialize error_code_ref in the current thread.
+
+  The size of the error code array is determined by the mesh in the context. In
+  a single-device environment, the array is a scalar. In a multi-device
+  environment, the array has the same shape as the mesh.
+  """
   with core.eval_context():
-    error_code = jnp.uint32(_NO_ERROR)
+    # Get mesh from the context.
+    mesh = mesh_lib.get_concrete_mesh()
+
+    if mesh is None:  # single-device case.
+      error_code = jnp.uint32(_NO_ERROR)
+
+    else:  # multi-device case.
+      sharding = NamedSharding(mesh, P(*mesh.axis_names))
+      error_code = jnp.full(
+          mesh.axis_sizes,
+          jnp.uint32(_NO_ERROR),
+          device=sharding,
+      )
+
     _error_storage.ref = core.mutable_array(error_code)


-def set_error_if(pred: jax.Array, msg: str) -> None:
+class error_checking_context:
+  """Redefine the error checking state based on the mesh in the context.
+
+  This context manager should be used when starting a multi-device
+  computation, and whenever the mesh is changed.
+
+  When exiting the context, the error checking state will be reset to the
+  original state.
+  """
+
+  __slots__ = ("old_ref",)
+
+  def __init__(self):
+    self.old_ref = None
+
+  def __enter__(self):
+    self.old_ref = _error_storage.ref
+    _initialize_error_code_ref()
+    return self
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    _error_storage.ref = self.old_ref
+
+
+def set_error_if(pred: jax.Array, /, msg: str) -> None:
   """Set error if any element of pred is true.

   If the error is already set, the new error will be ignored. It will not
   override the existing error.

   In auto mode, this function does not work under jit.
   """
   if _error_storage.ref is None:
     _initialize_error_code_ref()
@@ -76,7 +124,32 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None:
   new_error_code = jnp.uint32(len(_error_list))
   _error_list.append((msg, traceback))

-  pred = pred.any()
+  out_sharding = core.typeof(_error_storage.ref).sharding
+  in_sharding: NamedSharding = core.typeof(pred).sharding
+
+  if out_sharding.mesh.shape_tuple == ():  # single-device case.
+    pred = pred.any()
+  else:  # multi-device case.
+    has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types
+    if has_auto_axes:
+      raise NotImplementedError(
+          "Error checking in auto mode is not supported yet. Please use"
+          " explicit mode."
+      )
+    if out_sharding.mesh != in_sharding.mesh:
+      raise ValueError(
+          "The error code state and the predicate must be on the same mesh, "
+          f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. "
+          "Please use `with error_checking_context()` to redefine the error "
+          "code state based on the mesh."
+      )
+    pred = shard_map(
+        partial(jnp.any, keepdims=True),
+        mesh=out_sharding.mesh,
+        in_specs=in_sharding.spec,
+        out_specs=out_sharding.spec,
+    )(pred)  # perform per-device reduction

   error_code = _error_storage.ref[...]
   should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR))
   error_code = jnp.where(should_update, new_error_code, error_code)
@@ -93,7 +166,7 @@ def raise_if_error() -> None:
   if _error_storage.ref is None:  # if not initialized, do nothing
     return

-  error_code = _error_storage.ref[...]
+  error_code = _error_storage.ref[...].min()  # reduce to a single error code
   if isinstance(error_code, core.Tracer):
     raise ValueError(
         "raise_if_error() should not be called within a traced context, such as"
@@ -101,7 +174,11 @@ def raise_if_error() -> None:
     )
   if error_code == jnp.uint32(_NO_ERROR):
     return
-  _error_storage.ref[...] = jnp.uint32(_NO_ERROR)
+  _error_storage.ref[...] = jnp.full(
+      _error_storage.ref.shape,
+      jnp.uint32(_NO_ERROR),
+      device=_error_storage.ref.sharding,
+  )  # clear the error code

   msg, traceback = _error_list[error_code]
   exc = JaxValueError(msg)
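A hedged usage sketch of the error-checking API introduced above (the module path and the two-device mesh are assumptions, not part of this diff):

```python
import jax
import jax.numpy as jnp
from jax._src.error_check import (  # assumed module path
    error_checking_context, raise_if_error, set_error_if)

mesh = jax.make_mesh((2,), ("x",))  # assumes two devices
with jax.sharding.use_mesh(mesh):
    with error_checking_context():  # sizes the error-code array to the mesh
        x = jnp.arange(8.0)
        set_error_if(x < 0, "negative input")  # predicate is all-false here
        raise_if_error()  # would raise JaxValueError if an error were set
```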
@@ -322,12 +322,15 @@ vmappables: dict[type, tuple[type, type]] = {}
 spec_types: set[type] = {JumbleAxis}

 def unregister_vmappable(data_type: type) -> None:
-  spec_type, axis_size_type = vmappables.pop(data_type)
-  spec_types.remove(spec_type)
+  _, axis_size_type = vmappables.pop(data_type)
   del to_elt_handlers[data_type]
   del from_elt_handlers[data_type]
   if axis_size_type in make_iota_handlers:
     del make_iota_handlers[axis_size_type]
+  global spec_types
+  spec_types = (
+      {JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()}
+  )

 def is_vmappable(x: Any) -> bool:
   return type(x) is Jumble or type(x) in vmappables
@@ -797,7 +797,7 @@ def tracers_to_jaxpr(

   processed_eqn_ids = set()
   eqns: list[core.JaxprEqn] = []
-  for t in toposort([*in_tracers, *out_tracers]):
+  for t in toposort((*in_tracers, *out_tracers)):
     r = t.recipe
     if isinstance(r, JaxprEqnRecipe):
       # TODO broadcast_in_dim can create a new tracer, not present in parents
@@ -237,7 +237,7 @@ def batched_device_put(aval: core.ShapedArray,
           if (isinstance(x, array.ArrayImpl) and
               dispatch.is_single_device_sharding(x.sharding) and
               x.devices() == {d})]
-  if len(bufs) == len(xs):
+  if len(bufs) == len(xs) > 0:
     return array.ArrayImpl(
         aval, sharding, bufs, committed=committed, _skip_checks=True)
   return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
@@ -1026,24 +1026,101 @@ def clz(x: ArrayLike) -> Array:
   r"""Elementwise count-leading-zeros."""
   return clz_p.bind(x)

 @export
 def add(x: ArrayLike, y: ArrayLike) -> Array:
-  r"""Elementwise addition: :math:`x + y`."""
+  r"""Elementwise addition: :math:`x + y`.
+
+  This function lowers directly to the `stablehlo.add`_ operation.
+
+  Args:
+    x, y: Input arrays. Must have matching numerical dtypes. If neither
+      is a scalar, ``x`` and ``y`` must have the same number of dimensions
+      and be broadcast compatible.
+
+  Returns:
+    An array of the same dtype as ``x`` and ``y`` containing the sum
+    of each pair of broadcasted entries.
+
+  See also:
+    - :func:`jax.numpy.add`: NumPy-style addition supporting inputs
+      with mixed dtypes and ranks.
+
+  .. _stablehlo.add: https://openxla.org/stablehlo/spec#add
+  """
   return add_p.bind(x, y)

 @export
 def sub(x: ArrayLike, y: ArrayLike) -> Array:
-  r"""Elementwise subtraction: :math:`x - y`."""
+  r"""Elementwise subtraction: :math:`x - y`.
+
+  This function lowers directly to the `stablehlo.subtract`_ operation.
+
+  Args:
+    x, y: Input arrays. Must have matching numerical dtypes. If neither
+      is a scalar, ``x`` and ``y`` must have the same number of dimensions
+      and be broadcast compatible.
+
+  Returns:
+    An array of the same dtype as ``x`` and ``y`` containing the difference
+    of each pair of broadcasted entries.
+
+  See also:
+    - :func:`jax.numpy.subtract`: NumPy-style subtraction supporting
+      inputs with mixed dtypes and ranks.
+
+  .. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract
+  """
   return sub_p.bind(x, y)

 @export
 def mul(x: ArrayLike, y: ArrayLike) -> Array:
-  r"""Elementwise multiplication: :math:`x \times y`."""
+  r"""Elementwise multiplication: :math:`x \times y`.
+
+  This function lowers directly to the `stablehlo.multiply`_ operation.
+
+  Args:
+    x, y: Input arrays. Must have matching numerical dtypes. If neither
+      is a scalar, ``x`` and ``y`` must have the same number of dimensions
+      and be broadcast compatible.
+
+  Returns:
+    An array of the same dtype as ``x`` and ``y`` containing the product
+    of each pair of broadcasted entries.
+
+  See also:
+    - :func:`jax.numpy.multiply`: NumPy-style multiplication supporting
+      inputs with mixed dtypes and ranks.
+
+  .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
+  """
   return mul_p.bind(x, y)

 @export
 def div(x: ArrayLike, y: ArrayLike) -> Array:
   r"""Elementwise division: :math:`x \over y`.

-  Integer division overflow
-  (division by zero or signed division of INT_SMIN with -1)
-  produces an implementation defined value.
+  This function lowers directly to the `stablehlo.divide`_ operation.
+
+  Integer division overflow (division by zero or signed division of
+  INT_SMIN with -1) produces an implementation defined value.
+
+  Args:
+    x, y: Input arrays. Must have matching numerical dtypes. If neither
+      is a scalar, ``x`` and ``y`` must have the same number of dimensions
+      and be broadcast compatible.
+
+  Returns:
+    An array of the same dtype as ``x`` and ``y`` containing the quotient
+    of each pair of broadcasted entries. For integer inputs, any fractional
+    part is discarded.
+
+  See also:
+    - :func:`jax.numpy.divide`: NumPy-style true division supporting
+      inputs with mixed dtypes and ranks.
+    - :func:`jax.numpy.floor_divide`: NumPy-style floor division supporting
+      inputs with mixed dtypes and ranks.
+
+  .. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide
+  """
   return div_p.bind(x, y)
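A quick sketch exercising the semantics documented above (standard `jax.lax` calls):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4, dtype=jnp.float32)

lax.add(x, jnp.float32(1.0))          # scalar broadcasts: [1., 2., 3., 4.]
lax.mul(x, x)                         # elementwise product: [0., 1., 4., 9.]
lax.div(jnp.int32(7), jnp.int32(2))   # integer quotient, fraction discarded: 3
```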
@@ -8422,3 +8499,13 @@ mlir.register_lowering(optimization_barrier_p,
 def _optimization_barrier_batcher(batched_args, batch_dims, **params):
   return optimization_barrier_p.bind(*batched_args, **params), batch_dims
 batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher
+
+def _opt_barrier_jvp(primals, tangents):
+  tangents = [ad.instantiate_zeros(t) for t in tangents]
+  return optimization_barrier(primals), optimization_barrier(tangents)
+ad.primitive_jvps[optimization_barrier_p] = _opt_barrier_jvp
+
+def _opt_barrier_transpose(cts, *primals):
+  cts = [ad.instantiate_zeros(ct) for ct in cts]
+  return optimization_barrier(cts)
+ad.primitive_transposes[optimization_barrier_p] = _opt_barrier_transpose
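A small sketch of what the new JVP and transpose rules above make possible, assuming this commit: differentiating through `jax.lax.optimization_barrier`.

```python
import jax

def f(x):
    # optimization_barrier pins evaluation order without changing values.
    y = jax.lax.optimization_barrier(x * 2.0)
    return y * y

print(jax.grad(f)(3.0))  # d/dx (2x)^2 = 8x -> 24.0
```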
@@ -565,5 +565,5 @@ def use_concrete_mesh(mesh: Mesh | None):
   finally:
     jax_config.device_context.set_local(prev_val)

-def get_concrete_mesh():
+def get_concrete_mesh() -> Mesh | None:
   return jax_config.device_context.value
@@ -15,6 +15,7 @@
 """Module for pallas-core functionality."""
 from __future__ import annotations

+import collections
 from collections.abc import Callable, Iterable, Iterator, Sequence
 import contextlib
 import copy
@@ -1068,6 +1069,17 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **_):
   return [], effs


+class Mesh(Protocol):
+
+  @property
+  def backend(self) -> str:
+    ...
+
+  @property
+  def shape(self) -> collections.OrderedDict[object, int]:
+    ...
+
+
 _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}


@@ -1075,9 +1087,8 @@ def default_mesh_discharge_rule(
     in_avals,
     out_avals,
     *args,
-    grid,
+    mesh,
     compiler_params,
-    backend,
     jaxpr,
     debug,
     interpret,
@@ -1100,19 +1111,22 @@ def default_mesh_discharge_rule(
       if isinstance(eff, state_types.WriteEffect)
   )
   any_spec = BlockSpec(memory_space=MemorySpace.ANY)
+  grid_spec = GridSpec(
+      grid=tuple(mesh.shape.items()),
+      in_specs=[any_spec] * len(in_avals),
+      out_specs=[any_spec] * len(modified_idxs),
+  )
   from jax._src.pallas import pallas_call  # Avoid circular dependency.
-  outs = pallas_call.pallas_call(
+  outs = pallas_call._pallas_call(
       body,
       name=name,
       out_shape=[in_avals[idx] for idx in modified_idxs],
-      in_specs=[any_spec] * len(in_avals),
-      out_specs=[any_spec] * len(modified_idxs),
      input_output_aliases={
          in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs)
      },
-      grid=grid,
+      grid_spec=grid_spec,
+      mesh=mesh,
       compiler_params=compiler_params,
-      backend=backend,
       interpret=interpret,
       debug=debug,
       cost_estimate=cost_estimate,
@@ -340,11 +340,12 @@ def pallas_call_hlo_interpret(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: Any,
     cost_estimate: CostEstimate,
     out_avals: tuple[jax_core.AbstractValue, ...],
 ):
-  del compiler_params, cost_estimate, out_avals
+  del mesh, compiler_params, cost_estimate, out_avals
   debug_info = jaxpr.debug_info
   # If we're in interpret mode, we *scan* over the grid and eval the
   # discharged jaxpr.
@@ -211,6 +211,10 @@ class TensorCoreMesh:
   devices: np.ndarray
   axis_names: Sequence[str]

+  @property
+  def backend(self) -> str:
+    return "mosaic_tpu"
+
   @property
   def shape(self):
     return collections.OrderedDict(zip(self.axis_names, self.devices.shape))
@@ -259,7 +263,6 @@ def _tensorcore_mesh_discharge_rule(
     compiler_params = TPUCompilerParams()
   if len(mesh.shape) > 1:
     raise NotImplementedError("Mesh must be 1D")
-  core_axis_name, num_cores = list(mesh.shape.items())[0]
   if compiler_params.dimension_semantics is not None:
     raise ValueError(
         "dimension_semantics must be None for TensorCoreMesh"
@@ -269,13 +272,12 @@ def _tensorcore_mesh_discharge_rule(
       out_avals,
       *args,
       jaxpr=jaxpr,
-      grid=((core_axis_name, num_cores),),
+      mesh=mesh,
       compiler_params=compiler_params.replace(
           dimension_semantics=(PARALLEL,)
       ),
       debug=debug,
       interpret=interpret,
-      backend="mosaic_tpu",
       cost_estimate=cost_estimate,
       name=name,
   )
@@ -1351,12 +1351,13 @@ def interpret_pallas_call(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: Any,
     cost_estimate: CostEstimate,
     out_avals: tuple[jax_core.AbstractValue, ...],
     interpret_params: TPUInterpretParams,
 ):
-  del debug, cost_estimate, out_avals
+  del debug, mesh, cost_estimate, out_avals

   # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?)
   dynamic_grid_args, scalars, input_args = split_list(
@@ -108,6 +108,7 @@ def pallas_call_tpu_lowering_rule(
     *in_nodes,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
@@ -116,7 +117,8 @@ def pallas_call_tpu_lowering_rule(
     out_avals: tuple[jax_core.AbstractValue, ...],
 ):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
-  del interpret
+  del mesh, interpret  # Unused.
+
   debug_info = jaxpr._debug_info
   if debug:
     print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
@@ -126,11 +128,11 @@ def pallas_call_tpu_lowering_rule(
   else:
     mosaic_params = {}

-  mesh = None
+  jax_mesh = None
   axis_context = ctx.module_context.axis_context
   if axis_context is not None:
     if isinstance(axis_context, sharding_impls.SPMDAxisContext):
-      mesh = axis_context.mesh
+      jax_mesh = axis_context.mesh
   mlir_ctx = mlir.JaxIrContext()
   mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
   mlir_ctx.load_all_available_dialects()
@@ -147,7 +149,7 @@ def pallas_call_tpu_lowering_rule(
         grid_mapping,
         jaxpr,
         dimension_semantics=dimension_semantics,
-        mesh=mesh,
+        mesh=jax_mesh,
         for_verification=for_verification,
         dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(),
     )
@@ -164,11 +166,11 @@ def pallas_call_tpu_lowering_rule(
   )

   if promela_dump_path := _DUMP_PROMELA_TO.value:
-    num_devices = 1 if mesh is None else mesh.devices.size
+    num_devices = 1 if jax_mesh is None else jax_mesh.devices.size
     num_cores = (
         jax.devices()[0].num_cores
-        if mesh is None
-        else mesh.devices[0].num_cores
+        if jax_mesh is None
+        else jax_mesh.devices[0].num_cores
     )
     verification_module, _ = lower_module(for_verification=True)
     model = verification.export_promela_model(
@@ -18,7 +18,7 @@ from __future__ import annotations

 import abc
 import collections
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 import dataclasses
 import enum
 import itertools as it
@@ -519,9 +519,16 @@ class GPUMesh:
     )

   @property
-  def shape(self):
+  def backend(self) -> str:
+    return "mosaic_gpu"
+
+  @property
+  def shape(self) -> collections.OrderedDict[object, int]:
+    pairs: Iterable[tuple[object, int]]
     if self.num_threads is not None:
-      pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads))
+      pairs = zip(
+          self.axis_names, (*self.grid, *self.cluster, self.num_threads)
+      )
     else:
       pairs = tuple(
           zip(
@@ -563,8 +570,7 @@ def _gpu_mesh_discharge_rule(
     out_avals,
     *args,
     jaxpr=jaxpr,
-    grid=tuple(mesh.shape.items()),
-    backend="mosaic_gpu",
+    mesh=mesh,
     compiler_params=compiler_params,
     debug=debug,
     interpret=interpret,
|
@ -450,6 +450,7 @@ def _block_spec_from_block_mapping(
|
||||
|
||||
def lower_pipelined_jaxpr_to_module(
|
||||
grid_mapping: pallas_core.GridMapping,
|
||||
mesh: pallas_core.Mesh | None,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
compiler_params: dict[str, Any],
|
||||
cost_estimate: pallas_core.CostEstimate | None,
|
||||
@ -473,7 +474,10 @@ def lower_pipelined_jaxpr_to_module(
|
||||
block_mappings, [grid_mapping.num_inputs]
|
||||
)
|
||||
|
||||
if grid_mapping.grid_names: # Last dim corresponds to the warpgroup count
|
||||
if mesh is not None:
|
||||
assert isinstance(mesh, gpu_core.GPUMesh)
|
||||
if mesh and mesh.num_threads is not None:
|
||||
# Last dim corresponds to the warpgroup count.
|
||||
block = (128 * grid_mapping.grid[-1], 1, 1)
|
||||
grid = grid_mapping.grid[:-1]
|
||||
else:
|
||||
@ -566,6 +570,7 @@ def lower_pipelined_jaxpr_to_module(
|
||||
parallel_grid,
|
||||
grid_mapping.grid_names,
|
||||
block,
|
||||
mesh.cluster if mesh is not None else (),
|
||||
[bm.array_shape_dtype for bm in in_block_mappings],
|
||||
[bm.array_shape_dtype for bm in out_block_mappings],
|
||||
new_jaxpr,
|
||||
@ -578,6 +583,7 @@ def lower_jaxpr_to_module(
|
||||
grid: Sequence[int],
|
||||
grid_names: Sequence[str],
|
||||
block: Sequence[int],
|
||||
cluster: Sequence[int],
|
||||
in_shapes: Sequence[jax.ShapeDtypeStruct],
|
||||
out_shapes: Sequence[jax.ShapeDtypeStruct],
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
@ -640,7 +646,7 @@ def lower_jaxpr_to_module(
|
||||
mgpu_core._lower_as_gpu_kernel(
|
||||
body,
|
||||
grid=parallel_grid,
|
||||
cluster=(),
|
||||
cluster=cluster,
|
||||
block=block,
|
||||
in_shapes=in_shapes,
|
||||
out_shape=out_shapes,
|
||||
@ -1559,9 +1565,10 @@ def _reduce_lowering_rule_wg(
|
||||
if not out_aval.shape:
|
||||
# Special-case: reducing to a scalar.
|
||||
if x_aval.ndim != 1:
|
||||
# TODO(slebedev): Flatten to 1D, since vector.reduction only supports
|
||||
# 1D inputs.
|
||||
raise NotImplementedError("Only 1D inputs are supported")
|
||||
# Flatten to 1D, since vector.reduction only supports 1D inputs.
|
||||
x = vector_dialect.shape_cast(
|
||||
ir.VectorType.get([x_aval.size], out_type), x
|
||||
)
|
||||
return vector_dialect.ReductionOp(out_type, kind, x)
|
||||
acc = vector_dialect.splat(
|
||||
ir.VectorType.get(out_aval.shape, out_type),
|
||||
|
@@ -38,6 +38,7 @@ def pallas_call_lowering(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: dict[str, Any],
     cost_estimate: pallas_core.CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
@@ -63,6 +64,7 @@ def pallas_call_lowering(

   lowering_result = lowering.lower_pipelined_jaxpr_to_module(
       grid_mapping,
+      mesh,
       jaxpr,
       compiler_params,
       cost_estimate,
@@ -20,7 +20,7 @@ import dataclasses
 import enum
 from functools import partial, reduce
 import types
-from typing import Any, Literal
+from typing import Any, Literal, cast

 import jax
 from jax import lax
@@ -119,6 +119,7 @@ def _pallas_call_jvp_rule(
     jaxpr: jax_core.Jaxpr,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     debug: bool,
     interpret: bool,
     compiler_params: Any,
@@ -133,6 +134,8 @@ def _pallas_call_jvp_rule(
     raise NotImplementedError
   if input_output_aliases:
     raise NotImplementedError("JVP with aliasing not supported.")
+  if mesh is not None:
+    raise NotImplementedError("pallas_call with a mesh does not support JVP")
   nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
   tangents = [t for t in tangents if type(t) is not ad_util.Zero]
   nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs
@@ -181,6 +184,7 @@ def _pallas_call_jvp_rule(
       *tangents,
       jaxpr=jvp_jaxpr,
       grid_mapping=jvp_grid_mapping,
+      mesh=mesh,
       interpret=interpret,
       debug=debug,
       input_output_aliases=(),
@@ -317,6 +321,7 @@ def _batch_with_explicit_loop(
     *,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
@@ -384,6 +389,7 @@ def _batch_with_explicit_loop(
         *batch_args,
         jaxpr=jaxpr,
         grid_mapping=grid_mapping,
+        mesh=mesh,
         input_output_aliases=input_output_aliases,
         debug=debug,
         interpret=interpret,
@@ -413,6 +419,7 @@ def _pallas_call_batching_rule(
     *,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
@@ -421,6 +428,11 @@ def _pallas_call_batching_rule(
     out_avals: tuple[jax_core.AbstractValue, ...],
     backend: _Backend | None,
 ):
+  if mesh is not None:
+    raise NotImplementedError(
+        "pallas_call with a mesh does not support batching"
+    )
+
   def _maybe_squeeze_out_bdim(
       x: jax.Array, bdim: int | batching.NotMapped
   ) -> jax.Array:
@@ -445,6 +457,7 @@ def _pallas_call_batching_rule(
         *args,
         jaxpr=jaxpr,
         grid_mapping=grid_mapping,
+        mesh=mesh,
         input_output_aliases=input_output_aliases,
         debug=debug,
         interpret=interpret,
@@ -478,6 +491,7 @@ def _pallas_call_batching_rule(
       dims=dynamic_grid_dims + dims,
       jaxpr=jaxpr,
       grid_mapping=grid_mapping,
+      mesh=mesh,
      input_output_aliases=input_output_aliases,
      debug=debug,
      interpret=interpret,
@@ -512,6 +526,7 @@ def _pallas_call_batching_rule(
      dims=scalar_bdims + bdims,
      jaxpr=jaxpr,
      grid_mapping=grid_mapping,
+      mesh=mesh,
      input_output_aliases=input_output_aliases,
      debug=debug,
      interpret=interpret,
@@ -890,6 +905,7 @@ def _pallas_call_batching_rule(
      *args,
      jaxpr=jaxpr,
      grid_mapping=batched_grid_mapping,
+      mesh=mesh,
      input_output_aliases=input_output_aliases,
      debug=debug,
      interpret=interpret,
@@ -1339,12 +1355,13 @@ def _pallas_call_state_discharge_rule(
     jaxpr: jax_core.Jaxpr,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     debug: bool,
     interpret: bool,
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
-    backend: _Backend | None = None
+    backend: _Backend | None = None,
 ):
   del avals_out
   assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
@@ -1440,6 +1457,7 @@ def _pallas_call_state_discharge_rule(
       jaxpr=new_jaxpr,
       input_output_aliases=new_input_output_aliases,
       grid_mapping=new_grid_mapping,
+      mesh=mesh,
       debug=debug,
       interpret=interpret,
       compiler_params=compiler_params,
@@ -1526,16 +1544,6 @@ def pallas_call(
     invoke the Pallas kernel.

   """
-  if compiler_params is None:
-    compiler_params = {}
-  if isinstance(compiler_params, pallas_core.CompilerParams):
-    if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
-      raise ValueError(
-          f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
-    compiler_params = {
-        compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
-    }
-
   if grid_spec is None:
     grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes)
   else:
@@ -1556,6 +1564,55 @@ def pallas_call(
           "If `grid_spec` is specified, then `scratch_shapes` must "
           f"be `()`. It is {scratch_shapes}")
   del grid, in_specs, out_specs
+  return _pallas_call(
+      kernel,
+      out_shape,
+      grid_spec=grid_spec,
+      input_output_aliases=input_output_aliases,
+      debug=debug,
+      interpret=interpret,
+      name=name,
+      compiler_params=compiler_params,
+      cost_estimate=cost_estimate,
+      backend=backend,
+  )
+
+
+def _pallas_call(
+    kernel: Callable[..., None],
+    out_shape: Any,
+    *,
+    grid_spec: GridSpec,
+    mesh: pallas_core.Mesh | None = None,
+    input_output_aliases: dict[int, int] = {},
+    debug: bool = False,
+    interpret: bool = False,
+    name: str | None = None,
+    compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
+    cost_estimate: CostEstimate | None = None,
+    backend: _Backend | None = None,
+):
+  if compiler_params is None:
+    compiler_params = {}
+  if isinstance(compiler_params, pallas_core.CompilerParams):
+    if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
+      raise ValueError(
+          f"Unknown platform in compiler params: {compiler_params.PLATFORM}"
+      )
+    compiler_params = {
+        compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
+    }
+
+  if mesh is not None:
+    if tuple(mesh.shape.values()) != grid_spec.grid:
+      raise ValueError(
+          f"Mesh shape {tuple(mesh.shape.values())} does not match grid "
+          f"shape {grid_spec.grid}."
+      )
+    if backend is not None:
+      raise ValueError("If `mesh` is specified, then `backend` must be `None`.")
+    backend = cast(_Backend, mesh.backend)
+
   grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
   # TODO(necula): this canonicalization may be convenient for some usage
   # but it is lossy, because it prevents expressing functions that return
@@ -1643,6 +1700,7 @@
       debug=debug,
       interpret=interpret,
       grid_mapping=grid_mapping,
+      mesh=mesh,
       input_output_aliases=tuple(input_output_aliases.items()),
       compiler_params=compiler_params,
       cost_estimate=cost_estimate,
@@ -50,6 +50,7 @@ def pallas_call_lowering(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: dict[str, Any],
     cost_estimate: pallas_core.CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
@@ -64,6 +65,8 @@ def pallas_call_lowering(
     raise NotImplementedError(
         "scalar prefetch not implemented in the Triton backend"
     )
+  if mesh is not None:
+    raise NotImplementedError("mesh is not supported in the Triton backend")
   triton_params = compiler_params.get("triton", compiler_params)
   num_warps = triton_params.get("num_warps", 4)
   num_warps = 4 if num_warps is None else num_warps
@ -670,8 +670,8 @@ def choice(key: ArrayLike,
      ind = jnp.searchsorted(p_cuml, r).astype(int)
    else:
      # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
      g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)
      ind = jnp.argsort(g)[:n_draws]
      g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr)
      ind = lax.top_k(g, k=n_draws)[1].astype(int)
    result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)

  return result.reshape(shape if arr.ndim == 0 else
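Note (editorial, not part of the diff): a quick numpy check of the trick the comment cites. Taking the top-k of gumbel + log(p) samples k indices without replacement in proportion to p; with k=1 it reduces to the Gumbel-max trick.

    import numpy as np

    rng = np.random.default_rng(0)
    p = np.array([0.1, 0.2, 0.3, 0.4])
    counts = np.zeros_like(p)
    for _ in range(20_000):
      g = rng.gumbel(size=p.shape) + np.log(p)
      counts[np.argmax(g)] += 1  # the k=1 case of top-k
    print(counts / counts.sum())  # approximately [0.1 0.2 0.3 0.4]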
@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array:
          _uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))


def categorical(key: ArrayLike,
                logits: RealArray,
                axis: int = -1,
                shape: Shape | None = None) -> Array:
def categorical(
    key: ArrayLike,
    logits: RealArray,
    axis: int = -1,
    shape: Shape | None = None,
    replace: bool = True,
) -> Array:
  """Sample random values from categorical distributions.

  Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
  the Gumbel top-k trick. See [1] for reference.

  Args:
    key: a PRNG key used as the random key.
    logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike,
    shape: Optional, a tuple of nonnegative integers representing the result shape.
      Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
      The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
    replace: If True (default), perform sampling with replacement. If False,
      perform sampling without replacement via the Gumbel top-k trick.

  Returns:
    A random array with int dtype and shape given by ``shape`` if ``shape``
    is not None, or else ``np.delete(logits.shape, axis)``.

  References:
    .. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
       Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
       Proceedings of the 36th International Conference on Machine Learning, PMLR
       97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
  """
  key, _ = _check_prng_key("categorical", key)
  check_arraylike("categorical", logits)
  logits_arr = jnp.asarray(logits)

  if axis >= 0:
    axis -= len(logits_arr.shape)

  batch_shape = tuple(np.delete(logits_arr.shape, axis))
  if shape is None:
    shape = batch_shape
  else:
    shape = core.canonicalize_shape(shape)
    _check_shape("categorical", shape, batch_shape)

  shape_prefix = shape[:len(shape)-len(batch_shape)]
  logits_shape = list(shape[len(shape) - len(batch_shape):])
  logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
  return jnp.argmax(
      gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
      lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
      axis=axis)

  if replace:
    if axis >= 0:
      axis -= len(logits_arr.shape)

    logits_shape = list(shape[len(shape) - len(batch_shape):])
    logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
    return jnp.argmax(
        gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
        lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
        axis=axis)
  else:
    logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
    k = math.prod(shape_prefix)
    if k > logits_arr.shape[axis]:
      raise ValueError(
          f"Number of samples without replacement ({k}) cannot exceed number of "
          f"categories ({logits_arr.shape[axis]})."
      )

    _, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k)
    assert indices.shape == batch_shape + (k,)
    assert shape == shape_prefix + batch_shape

    dimensions = (indices.ndim - 1, *range(indices.ndim - 1))
    indices = lax.reshape(indices, shape, dimensions)
    assert indices.shape == shape
    return indices


def laplace(key: ArrayLike,
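Example usage of the new keyword (assuming a jax build that includes this change): with replace=False the k draws are guaranteed distinct.

    import jax
    import jax.numpy as jnp

    logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
    key = jax.random.key(0)
    idx = jax.random.categorical(key, logits, shape=(3,), replace=False)
    assert len(set(idx.tolist())) == 3  # three distinct categories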
@ -114,9 +114,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
  return sdy_sharding


@util.cache(max_size=128, trace_context_in_key=False)
def get_replicated_hlo_sharding():
  return xc.HloSharding.replicate()
replicated_hlo_sharding = xc.HloSharding.replicate()


@use_cpp_class(xc.SingleDeviceSharding)
@ -183,7 +181,7 @@ class SingleDeviceSharding(jsharding.Sharding):
    return (self._device,)

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    return get_replicated_hlo_sharding()
    return replicated_hlo_sharding

  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
    sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True)
@ -401,7 +399,7 @@ def _op_sharding_to_pos_sharding(
  def _positional_sharding_to_xla_hlo_sharding(
      self, num_dimensions: int) -> xc.HloSharding:
    if self.shape == (1,) * self.ndim:
      return get_replicated_hlo_sharding()
      return replicated_hlo_sharding

    pbuf = xc.OpSharding()
    shape = self.shape[self.ndim - num_dimensions:]  # 'rank promotion' of val
@ -603,7 +601,7 @@ class GSPMDSharding(jsharding.Sharding):
  @functools.cached_property
  def _hlo_sharding_hash(self):
    if self.is_fully_replicated:
      return hash(get_replicated_hlo_sharding())
      return hash(replicated_hlo_sharding)
    return hash(self._hlo_sharding)

  def __eq__(self, other):
@ -669,7 +667,7 @@ class GSPMDSharding(jsharding.Sharding):

  @classmethod
  def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
    return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
    return cls(tuple(device_assignment), replicated_hlo_sharding,
               memory_kind=memory_kind)

@ -244,52 +244,62 @@ def curry(f):
  """
  return wraps(f)(partial(partial, f))

def toposort(end_nodes):
  if not end_nodes: return []
  end_nodes = _remove_duplicates(end_nodes)
# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum.
toposort: Callable[[Iterable[Any]], list[Any]]
if hasattr(jaxlib_utils, "topological_sort"):
  toposort = partial(jaxlib_utils.topological_sort, "parents")
else:

  child_counts = {}
  stack = list(end_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(node.parents)
  for node in end_nodes:
    child_counts[id(node)] -= 1
  def toposort(end_nodes):
    if not end_nodes:
      return []
    end_nodes = _remove_duplicates(end_nodes)

  sorted_nodes = []
  childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
  assert childless_nodes
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in node.parents:
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
    child_counts = {}
    stack = list(end_nodes)
    while stack:
      node = stack.pop()
      if id(node) in child_counts:
        child_counts[id(node)] += 1
      else:
        child_counts[id(parent)] -= 1
  sorted_nodes = sorted_nodes[::-1]
        child_counts[id(node)] = 1
        stack.extend(node.parents)
    for node in end_nodes:
      child_counts[id(node)] -= 1

  check_toposort(sorted_nodes)
  return sorted_nodes
    sorted_nodes = []
    childless_nodes = [
        node for node in end_nodes if child_counts[id(node)] == 0
    ]
    assert childless_nodes
    while childless_nodes:
      node = childless_nodes.pop()
      sorted_nodes.append(node)
      for parent in node.parents:
        if child_counts[id(parent)] == 1:
          childless_nodes.append(parent)
        else:
          child_counts[id(parent)] -= 1
    sorted_nodes = sorted_nodes[::-1]

def check_toposort(nodes):
  visited = set()
  for node in nodes:
    assert all(id(parent) in visited for parent in node.parents)
    visited.add(id(node))
    check_toposort(sorted_nodes)
    return sorted_nodes

  def check_toposort(nodes):
    visited = set()
    for node in nodes:
      assert all(id(parent) in visited for parent in node.parents)
      visited.add(id(node))

def _remove_duplicates(node_list):
  seen = set()
  out = []
  for n in node_list:
    if id(n) not in seen:
      seen.add(id(n))
      out.append(n)
  return out

  def _remove_duplicates(node_list):
    seen = set()
    out = []
    for n in node_list:
      if id(n) not in seen:
        seen.add(id(n))
        out.append(n)
    return out

def split_merge(predicate, xs):
  sides = list(map(predicate, xs))
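Note (editorial, not part of the diff): a small check of the ordering contract (parents always precede their children), using a minimal node type with the `parents` attribute the sort expects. The import path is assumed from the module shown above.

    from jax._src.util import toposort  # assumed import path for this sketch

    class Node:
      def __init__(self, name, parents=()):
        self.name, self.parents = name, list(parents)

    a = Node("a"); b = Node("b", [a]); c = Node("c", [a]); d = Node("d", [b, c])
    names = [n.name for n in toposort([d])]
    assert names.index("a") < names.index("b") < names.index("d")
    assert names.index("a") < names.index("c") < names.index("d")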
@ -658,17 +668,12 @@ def use_cpp_class(cpp_cls: type[Any]) -> Callable[[type[T]], type[T]]:

    exclude_methods = {'__module__', '__dict__', '__doc__'}

    originals = {}
    for attr_name, attr in cls.__dict__.items():
      if attr_name not in exclude_methods:
        if hasattr(_original_func(attr), "_use_cpp"):
          originals[attr_name] = attr
        else:
        if not hasattr(_original_func(attr), "_use_cpp"):
          setattr(cpp_cls, attr_name, attr)

    cpp_cls.__doc__ = cls.__doc__
    # TODO(pschuh): Remove once fastpath is gone.
    cpp_cls._original_py_fns = originals
    return cpp_cls

  return wrapper
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for serialization and deserialization of GDA."""

import asyncio
import math
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mnist_lib, saved_model_lib, saved_model_main."""

import os
from absl import flags
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for call_tf."""

from collections.abc import Callable
import contextlib
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the jax2tf conversion for control-flow primitives."""

from absl.testing import absltest

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the shape-polymorphic jax2tf conversion."""

from __future__ import annotations

@ -320,6 +320,20 @@ def _vector_splat_op_lowering_rule(
  return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)]


@_register_lowering(vector.ShapeCastOp)
def _vector_shape_cast_op_lowering_rule(
    _: LoweringContext, op: vector.ShapeCastOp
) -> Sequence[ir.Value]:
  [layout] = inference_utils.in_layouts(op)
  out_vec_ty = ir.VectorType(op.result.type)
  assert out_vec_ty.has_static_shape
  is_signed = (
      False if ir.IntegerType.isinstance(out_vec_ty.element_type) else None
  )
  a = _fragmented_array_from_ir(op.source, layout, is_signed)
  return [_fragmented_array_to_ir(a.reshape(out_vec_ty.shape), out_vec_ty)]


@_register_lowering(vector.ReductionOp)
def _vector_reduction_op_lowering_rule(
    ctx: LoweringContext, op: vector.ReductionOp
@ -382,21 +382,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]):
  return WGMMA_LAYOUT


def _tiled_wgmma_layout_for_upcast(shape: tuple[int, ...]):
  """Returns a tiled layout that is easy to relayout to WGMMA layout after doubling the bitwidth."""
  if len(shape) != 2:
    raise ValueError(f"Shape {shape} is not 2D")
  if shape[0] % 64 != 0 or shape[1] % 8 != 0:
    raise ValueError(f"Shape {shape} is not a multiple of 64x8")
  t = Tiling(((64, 16), (16, 16), (8, 16), (4,), (2, 1)))
  return TiledLayout(
      t,
      warp_dim=-9,
      lane_dims=(-5, -2, -4),
      vector_dim=-3,
  )


@dataclasses.dataclass(frozen=True)
class WGMMARowFragLayout:
  """[m] matrix, where m % 64 == 0."""
@ -505,13 +490,55 @@ WGMMA_ROW_LAYOUT = WGMMARowFragLayout()

# The tiled layout is equivalent to one described here in PTX documentation:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d
# In this layout, we partition the 64x8 tiles over 4 warpgroups into 16x8 tiles.
# Then, we further split the 16x8 tiles into 8x8 submatrices which are the unit
# of data that is split across a warp. Since 8*8 = 64, but a warp has only 32
# threads, we vectorize pairs of elements along columns.
# The assignment of elements to warp lanes is as follows:
#
#   0  0  1  1  2  2  3  3
#   4  4  5  5  6  6  7  7
#   8  8  9  9 10 10 11 11
#  12 12 13 13 14 14 15 15
#  ...
WGMMA_LAYOUT = TiledLayout(
    Tiling(((64, 8), (16, 8), (8, 8), (1, 2))),
    warp_dim=-8,
    lane_dims=(-4, -3),
    vector_dim=-1,
)
# This tiled layout is similar to the one above. Above, each warp stores a 8x8
# This tiled layout is similar to the WGMMA layout, only the unit at which we
# assign submatrices to warps grows from 8x8 to 8x16. The elements within each
# submatrix are assigned to threads in the following way:
#
#   0  0  0  0  2  2  2  2  1  1  1  1  3  3  3  3
#   4  4  4  4  6  6  6  6  5  5  5  5  7  7  7  7
#  ...
#
# Our vector length is twice the size of that of WGMMA_LAYOUT, which lets us use
# 32-bit SMEM loads/stores when dealing with 8-bit values. The conversion to the
# WGMMA layout only requires communication between threads whose lane indices
# differ in their lowest bit (i.e. 0 and 1, 2 and 3), so it takes just a single
# warp shuffle (plus permutes local to each thread).
WGMMA_LAYOUT_UPCAST_2X = TiledLayout(
    Tiling(((64, 16), (16, 16), (8, 16), (8,), (4,))),
    warp_dim=-8,
    lane_dims=(-4, -2, -3),
    vector_dim=-1,
)
# This layout should be used when upcasting 4-bit elements to 16-bit, for the
# purpose of passing them into WGMMA later. The core matrices stored by a warp
# are 8x32, because each of the 4 threads in a row holds 8 elements in a single
# vector. Note that unlike WGMMA_LAYOUT_UPCAST_2X, we assign columns to each
# group of 4 threads in order (as opposed to the swapping between 1 and 2,
# 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does).
WGMMA_LAYOUT_UPCAST_4X = TiledLayout(
    Tiling(((64, 32), (16, 32), (8, 32), (8,))),
    warp_dim=-7,
    lane_dims=(-3, -2),
    vector_dim=-1,
)
# This tiled layout is similar to WGMMA_LAYOUT. There, each warp stores a 8x8
# submatrix in the following way (we only show the first 4 rows for brevity):
#
#   0  0  1  1  2  2  3  3
|
||||
At the moment, only conversions from ``WGSplatFragLayout`` are supported.
|
||||
"""
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
c = lambda x: arith.constant(i32, x)
|
||||
if self.layout == new_layout:
|
||||
return self
|
||||
shape = self.shape
|
||||
@ -707,24 +735,148 @@ class FragmentedArray:
|
||||
):
|
||||
is_even_row = arith.cmpi(
|
||||
arith.CmpIPredicate.eq,
|
||||
arith.remui(arith.divui(utils.thread_idx(), c(4, i32)), c(2, i32)),
|
||||
c(0, i32),
|
||||
arith.remui(arith.divui(utils.thread_idx(), c(4)), c(2)),
|
||||
c(0),
|
||||
)
|
||||
perm = arith.select(is_even_row, c(0x5410, i32), c(0x3276, i32))
|
||||
perm = arith.select(is_even_row, c(0x5410), c(0x3276))
|
||||
new_regs = []
|
||||
for reg in self.registers.flat:
|
||||
reg_ty = reg.type
|
||||
reg = utils.bitcast(reg, i32)
|
||||
reg_shfl = utils.shfl_bfly(reg, 4)
|
||||
new_reg = llvm.inline_asm(
|
||||
i32, [reg, reg_shfl, perm], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r"
|
||||
)
|
||||
new_reg = utils.prmt(reg, reg_shfl, perm)
|
||||
new_regs.append(utils.bitcast(new_reg, reg_ty))
|
||||
return FragmentedArray(
|
||||
_registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)),
|
||||
_layout=new_layout,
|
||||
_is_signed=self.is_signed,
|
||||
)
|
    if (
        self.layout == WGMMA_LAYOUT_UPCAST_2X
        and new_layout == WGMMA_LAYOUT
        and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16
    ):
      assert shape[1] % 16 == 0  # Should be implied by the layout
      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
      is_even = arith.cmpi(
          arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
      )
      registers = self.registers
      if dtype_bitwidth == 4:
        if registers.shape[1] % 2:
          raise NotImplementedError(
              "This relayout implementation requires an even number of column"
              " tiles (to pack pairs of them for efficiency)"
          )
        # We pair up the consecutive column tiles, so each register is 32-bit.
        # If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout,
        # LLVM will realize that the paired up vectors actually came from the
        # same 32-bit register and it will become a no-op.
        col_minor_registers = np.moveaxis(registers, 1, -1)
        flat_registers = [
            utils.vector_concat((l, h))
            for l, h in zip(
                col_minor_registers.flat[::2], col_minor_registers.flat[1::2]
            )
        ]
        registers = np.asarray(flat_registers, dtype=object).reshape(
            *col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2
        )
        registers = np.moveaxis(registers, -1, 1)
      for idx, reg in np.ndenumerate(registers):
        if dtype_bitwidth == 16:
          assert reg.type.shape == [4]
          # A single vector is 64-bits, but shuffles are only 32-bit wide.
          # We only shuffle the half that needs to go to other thread.
          low = utils.vector_slice(reg, slice(0, 2))
          high = utils.vector_slice(reg, slice(2, 4))
          to_exchange = arith.select(is_even, high, low)
          # Exchange values between even and odd threads.
          exchanged = utils.shfl_bfly(to_exchange, 1)
          low = arith.select(is_even, low, exchanged)
          high = arith.select(is_even, exchanged, high)
          new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low
          new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high
        elif dtype_bitwidth == 8:
          assert reg.type.shape == [4]
          # The vector is 32-bits, so we just shuffle the whole thing and
          # use prmt to blend it with the local register.
          exchanged = utils.shfl_bfly(reg, 1)
          # Consider lanes 0 and 1, because the situation is symmetric for
          # each pair. If we feed reg[lane] and exchanged[lane] (which is
          # really the same as reg of the other lane) to prmt, we can index
          # the elements of the result using the following indices:
          #          reg[0]: 0 1 2 3    reg[1]: 8 9 10 11
          # prmt[0]: 0 1 2 3 4 5 6 7
          # prmt[1]: 4 5 6 7 0 1 2 3
          # The expected outputs and their respective permutations are:
          # out[0]:  0 1 8 9    out[1]:  2 3 10 11
          # prmt[0]: 0 1 4 5    prmt[1]: 6 7 2 3
          # Note that the patterns still need to be flipped, since we listed
          # bytes with LSB on the left, which is the opposite of how the
          # numeric constants are spelled in Python (LSB on the right).
          perm = arith.select(is_even, c(0x5410), c(0x3276))
          blend = utils.prmt(reg, exchanged, perm)
          for i in range(2):
            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
            new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
        else:
          assert dtype_bitwidth == 4
          assert reg.type.shape == [8]  # We paired up the registers above.
          exchanged = utils.shfl_bfly(reg, 1)
          # See comment above for a more complete explanation.
          #          reg[0]: 0 1 2 3 16 17 18 19    reg[1]: 8 9 10 11 24 25 26 27
          # prmt[0]: -0- -1- --2-- --3-- -4- --5-- --6-- --7--
          # prmt[1]: -4- -5- --6-- --7-- -0- --1-- --2-- --3--
          # The expected outputs and their respective permutations are:
          # out[0]:  0 1 8 9 16 17 24 25    out[1]:  2 3 10 11 18 19 26 27
          # prmt[0]: -0- -4- --2-- --6--    prmt[1]: -5- --1-- --7-- --3--
          perm = arith.select(is_even, c(0x6240), c(0x3715))
          blend = utils.prmt(reg, exchanged, perm)
          for i in range(4):
            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
            new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg
      assert all(r is not None for r in new_registers)
      return FragmentedArray(
          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
      )
    if (
        self.layout == WGMMA_LAYOUT_UPCAST_4X
        and new_layout == WGMMA_LAYOUT_UPCAST_2X
        and utils.bitwidth(self.mlir_dtype) == 4
    ):
      assert shape[0] % 64 == 0  # Should be implied by the layout
      assert shape[1] % 32 == 0  # Should be implied by the layout
      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
      i32 = ir.IntegerType.get_signless(32)
      c = lambda x: arith.constant(i32, x)
      is_01 = arith.cmpi(
          arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2)
      )
      for idx, reg in np.ndenumerate(self.registers):
        assert ir.VectorType(reg.type).shape == [8]
        # The vector is 32-bits, so we just shuffle the whole thing and
        # use prmt to blend it with the local register.
        exchanged = utils.shfl_bfly(reg, 2)
        # See comments above for conventions. Here we exchange data between
        # threads with lane index related by flipping 2nd bit (e.g. 0 and 2).
        #          reg[0]: 0 1 2 3 4 5 6 7    reg[2]: 16 17 18 19 20 21 22 23
        # prmt[0]: -0- -1- -2- -3- --4-- --5-- --6-- --7--
        # prmt[1]: -4- -5- -6- -7- --0-- --1-- --2-- --3--
        # The expected outputs and their respective permutations are:
        # out[0]:  0 1 2 3 16 17 18 19    out[2]:  4 5 6 7 20 21 22 23
        # prmt[0]: -0- -1- --4-- --5--    prmt[2]: -6- -7- --2-- --3--
        perm = arith.select(is_01, c(0x5410), c(0x3276))
        blend = utils.prmt(reg, exchanged, perm)
        for i in range(2):
          reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4))
          new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
      assert all(r is not None for r in new_registers)
      return FragmentedArray(
          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
      )
    if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT:
      return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout)
    if not isinstance(self.layout, WGSplatFragLayout):
      raise NotImplementedError(
          f"Cannot convert from {self.layout} to {new_layout}"
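Note (editorial, not part of the diff): every relayout path above relies on the pairing rule of the butterfly shuffle: shfl.bfly with distance d exchanges data between lanes i and i ^ d. A one-liner makes the pairings concrete.

    for d in (1, 2, 4):
      print(d, [i ^ d for i in range(8)])
    # 1 -> [1, 0, 3, 2, 5, 4, 7, 6]   (0 and 1, 2 and 3, ...)
    # 2 -> [2, 3, 0, 1, 6, 7, 4, 5]   (0 and 2, 1 and 3, ...)
    # 4 -> [4, 5, 6, 7, 0, 1, 2, 3]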
@ -1178,11 +1330,15 @@ class FragmentedArray:
    is_vector_reg = ir.VectorType.isinstance(reg_type)
    reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
    [vector_len] = reg_shape  # This is meant to be a 1D assertion.
    if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len == 2:
    if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8:
      raise ValueError(
          "Register bitwidth in target type must be divisible by 8, got"
          f" {new_reg_bitwidth}"
      )
    if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
      new_registers = np.empty_like(self.registers)
      empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32))
      out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
      for idx, reg in np.ndenumerate(self.registers):
        reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
        # The algorithm here is largely the same as CUTLASS's
        # NumericArrayConverter specialization for int4 -> bf16 casts.
        # We modify it slightly, because we only extract 2 values.
@ -1196,25 +1352,58 @@ class FragmentedArray:
        # positive int4s will end up larger than negative int4s, with a bias of
        # 8. We use the sub to subtract the base (our initial exponent) and the
        # bias coming from flipping the sign bit which is 136 (0x4308 as bits).
        new_reg_32 = llvm.inline_asm(
            i32,
            [reg_8],
            """
            {
            .reg .b32 s<4>;
            shr.s32 s0, $1, 4;
            prmt.b32 s1, $1, s0, 0xF4F0;
            lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
            mov.b32 s3, 0x43084308;
            sub.bf16x2 $0, s2, s3;
            }
            """,
            "=r,r",
        )
        new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
        new_registers[idx] = vector.bitcast(
            ir.VectorType.get((vector_len,), new_dtype), new_vec_32
        )
        def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
          assert 0 <= part < 4
          return llvm.inline_asm(
              i32,
              [reg, reg_shr],
              f"""
              {{
              .reg .b32 s<4>;
              prmt.b32 s1, $1, $2, 0xF{part + 4}F{part};
              lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
              mov.b32 s3, 0x43084308;
              sub.bf16x2 $0, s2, s3;
              }}
              """,
              "=r,r,r",
          )
        offset = 0
        out_int_regs = []
        for group_size in (8, 4, 2):
          int_ty = ir.IntegerType.get_signless(group_size * 4)
          while vector_len - offset >= group_size:
            # If the vector originates from a slice (common after relayouts), we
            # can fuse the slicing into the conversion and prevent LLVM from
            # generating a bunch of shifts to align the vector data to the LSB.
            # This also lets us share the right shift among more vectors.
            if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
                and utils.bitwidth(slice_op.vector.type) == 32
                and slice_op.strides[0].value == 1):
              slice_offset = slice_op.offsets[0].value + offset
              reg_int = utils.bitcast(slice_op.vector, i32)
              reg_int_shr = arith.shrui(reg_int, c(4, i32))
              out_int_regs.extend(
                  upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
                  for part in range(group_size // 2)
              )
            else:
              reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
              reg_slice_int = utils.bitcast(reg_slice, int_ty)
              if int_ty != i32:
                reg_slice_int = arith.extsi(i32, reg_slice_int)
              reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
              out_int_regs.extend(
                  upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
                  for part in range(group_size // 2)
              )
            offset += group_size
        assert offset == vector_len
        out_vec_int = utils.vector_concat([
            vector.splat(ir.VectorType.get((1,), i32), reg)
            for reg in out_int_regs
        ])
        new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
      return FragmentedArray(
          _registers=new_registers, _layout=self.layout, _is_signed=None
      )
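Note (editorial, not part of the diff): the 0x4308 constant in the PTX above can be checked in pure Python. Placing the sign-flipped int4 nibble into the low mantissa bits of the bf16 pattern 0x4308 (which is 136.0) yields 136 + x, so the final subtraction recovers x exactly.

    import numpy as np

    for x in range(-8, 8):
      nibble = x & 0xF            # two's-complement int4 bits
      bits = 0x4308 ^ nibble      # what the lop3 computes per half-word
      f = np.uint32(bits << 16).view(np.float32)  # bf16 -> f32 by zero-extending
      assert f - 136.0 == x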
@ -1263,11 +1452,6 @@ class FragmentedArray:
          _registers=new_registers, _layout=self.layout, _is_signed=is_signed
      )
    # Generic path.
    # XLA packs elements into bytes in big-endian order, while LLVM assumes the
    # same endianness as the target machine (which is little for NVIDIA GPUs).
    # We'll need to add specialized casting routines that flip the endianness.
    if 1 < utils.bitwidth(cur_dtype) < 8 or 1 < utils.bitwidth(new_dtype) < 8:
      raise NotImplementedError("Conversion involving sub-byte types unsupported")
    from_float = ir.FloatType.isinstance(cur_dtype)
    to_float = ir.FloatType.isinstance(new_dtype)
    from_integer = ir.IntegerType.isinstance(cur_dtype)
@ -1472,17 +1656,17 @@ class FragmentedArray:
  def reshape(self, shape):
    if self.shape == shape:
      return self

    if not isinstance(self.layout, WGSplatFragLayout):
      raise NotImplementedError(self.layout)

    if np.prod(shape) != np.prod(self.shape):
    if math.prod(shape) != math.prod(self.shape):
      raise ValueError(f"Can't reshape {self.shape} to {shape}")

    match self.layout:
      case WGSplatFragLayout() | WGStridedFragLayout():
        new_layout = dataclasses.replace(self.layout, shape=shape)
      case _:
        raise NotImplementedError(self.layout)

    return FragmentedArray(
        _registers=self.registers,
        _layout=WGSplatFragLayout(shape),
        _is_signed=self.is_signed,
        _registers=self.registers, _layout=new_layout, _is_signed=self.is_signed
    )

  def broadcast_minor(self, n):
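One plausible motivation for the np.prod -> math.prod switch above (my inference; the diff does not say): np.prod uses fixed-width integer arithmetic and silently wraps on large shapes, while math.prod is exact for Python ints.

    import math
    import numpy as np

    shape = (2**40, 2**40)
    print(math.prod(shape))     # 1208925819614629174706176 (exact)
    print(int(np.prod(shape)))  # 0 where the int64 product wraps around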
@ -336,6 +336,37 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:

  return [], [layout]


def _update_layout_shape(
    layout: ir.Attribute, shape: Sequence[int], origin: str
) -> ir.Attribute:
  if layouts_lib.is_splat_fragmented_layout(
      layout
  ) or layouts_lib.is_strided_fragmented_layout(layout):
    return layouts_lib.to_layout_attr(
        dataclasses.replace(layouts_lib.from_layout_attr(layout), shape=shape)
    )
  raise NotImplementedError(f"Unsupported {origin} layout: {layout}.")


@partial(_add_layout_inference_rule, vector.ShapeCastOp)
def _infer_shape_cast_op_layout(op: vector.ShapeCastOp) -> OptionalLayouts:
  in_layout = inference_utils.value_layout(op.source)
  if in_layout is None:
    out_layout = inference_utils.value_layout(op.result)
    if out_layout is None:
      return None
    in_layout = _update_layout_shape(
        out_layout, ir.VectorType(op.source.type).shape, "source"
    )
    return [in_layout], [out_layout]

  out_layout = _update_layout_shape(
      in_layout, ir.VectorType(op.result.type).shape, "result"
  )
  return [in_layout], [out_layout]


@partial(_add_layout_inference_rule, vector.ReductionOp)
def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts:
  if layout := inference_utils.value_layout(op.vector):
@ -83,6 +83,8 @@ def mma(
    accumulate: ir.Value | bool = True,
    collective: bool = False,
):
  if a_swizzle == 16 or b_swizzle == 16:
    raise NotImplementedError("No swizzle is not supported")
  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)
  if isinstance(accumulate, bool):
@ -25,8 +25,12 @@ from typing import cast

from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import vector

from . import fragmented_array as fa
from . import inference_utils
from . import layouts as layouts_lib
from . import utils

# mypy: ignore-errors
@ -39,7 +43,9 @@ _transform_inference_rules: dict[str, TransformInferenceRule] = {}
def _add_transform_inference_rule(
    op: type[ir.OpView], rule: TransformInferenceRule
):
  _transform_inference_rules[op.OPERATION_NAME] = rule  # pytype: disable=attribute-error
  if op is not None:
    _transform_inference_rules[op.OPERATION_NAME] = rule  # pytype: disable=attribute-error
  return rule


def _set_transform_attributes(
@ -110,6 +116,86 @@ def _infer_async_load_transforms(op: mgpu.AsyncLoadOp) -> OptionalTransforms:
  return None if in_transforms is None else ([in_transforms], [])


@partial(_add_transform_inference_rule, vector.LoadOp)
@partial(_add_transform_inference_rule, vector.StoreOp)
def _infer_vector_load_store_transforms(
    op: vector.LoadOp | vector.StoreOp,
) -> OptionalTransforms:
  for i in op.indices:
    index_defining_op = i.owner.opview
    if (
        not isinstance(index_defining_op, arith.ConstantOp)
        or index_defining_op.literal_value != 0
    ):
      # TODO(bchetioui): handle slicing.
      raise NotImplementedError(
          f"Only constants with value 0 are supported as indices for {op}"
      )

  if isinstance(op, vector.LoadOp):
    [layout_attr] = inference_utils.out_layouts(op)
  else:
    assert isinstance(op, vector.StoreOp)
    [layout_attr] = inference_utils.in_layouts(op)

  layout = layouts_lib.from_layout_attr(layout_attr)
  transforms = inference_utils.value_transforms(op.base)

  if layout == fa.WGMMA_LAYOUT:
    layout_transforms = infer_transforms_for_wgmma_ref(
        ir.MemRefType(op.base.type)
    )
  elif (isinstance(layout, fa.WGStridedFragLayout) or
        isinstance(layout, fa.WGSplatFragLayout)):
    layout_transforms = None
  else:
    raise NotImplementedError(
        f"Got layout {layout} which is not yet supported"
    )

  if transforms is not None and layout_transforms is not None:
    if transforms != layout_transforms:
      raise NotImplementedError(
          f"Conflicting transforms for {op.base} in {op}: "
          f"{transforms} != {layout_transforms}."
      )
    return [transforms], []

  if transforms is not None:
    return [transforms], []

  if layout_transforms is not None:
    return [layout_transforms], []

  return None


# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)

@partial(_add_transform_inference_rule, SliceSMEMOp)
def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
  transforms = None
  uses = cast(ir.OpResult, op.result).uses

  for op_operand_use in uses:
    consumer = op_operand_use.owner
    op_user = consumer.operands[op_operand_use.operand_number]
    out_transforms = inference_utils.in_transforms_for_operand(
        consumer, op_user
    )
    if transforms is not None and out_transforms is not None:
      if transforms != out_transforms:
        raise NotImplementedError(
            f"Conflicting transforms for {op_user} in {op}: "
            f"{transforms} != {out_transforms}."
        )
    elif out_transforms is not None:
      transforms = out_transforms

  return None if transforms is None else ([], [transforms])


def _should_have_transforms(op: ir.OpView) -> bool:
  """Returns 'True' if the operation should be assigned in/out transforms."""
  return any(
@ -346,8 +346,11 @@ def bitwidth_impl(ty: ir.Type):
    return ir.IntegerType(ty).width
  if ir.FloatType.isinstance(ty):
    return ir.FloatType(ty).width
  if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
  if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
    return MBARRIER_BYTES * 8
  if ir.VectorType.isinstance(ty):
    vty = ir.VectorType(ty)
    return math.prod(vty.shape) * bitwidth(vty.element_type)
  raise NotImplementedError(ty)


@ -1180,13 +1183,33 @@ def shfl_bfly(x: ir.Value, distance: int | ir.Value):
  i32 = ir.IntegerType.get_signless(32)
  if isinstance(distance, int):
    distance = c(distance, i32)
  assert x.type == i32
  return nvvm.shfl_sync(
  if (result_type := x.type) != i32:
    x = bitcast(x, i32)
  y = nvvm.shfl_sync(
      i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly,
  )
  return bitcast(y, result_type)


def prmt(high: ir.Value, low: ir.Value, permutation: ir.Value):
  i32 = ir.IntegerType.get_signless(32)
  if (result_type := high.type) != low.type:
    raise ValueError(f"Types must match, got {high.type} and {low.type}")
  if high.type != i32:
    high = bitcast(high, i32)
  if low.type != i32:
    low = bitcast(low, i32)
  if permutation.type != i32:
    permutation = bitcast(permutation, i32)
  result = llvm.inline_asm(
      i32, [high, low, permutation], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r"
  )
  return bitcast(result, result_type)

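Note (editorial, not part of the diff): a pure-Python model of the prmt.b32 semantics the helper wraps. Each selector nibble picks one of eight source bytes (0-3 from the first operand, 4-7 from the second).

    def prmt_model(a: int, b: int, sel: int) -> int:
      src = a.to_bytes(4, "little") + b.to_bytes(4, "little")
      out = bytes(src[(sel >> (4 * i)) & 0xF] for i in range(4))
      return int.from_bytes(out, "little")

    # 0x5410 interleaves the low halves: bytes 0,1 of a with bytes 0,1 of b.
    assert prmt_model(0x33221100, 0x77665544, 0x5410) == 0x55441100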
def bitcast(x: ir.Value, new_type: ir.Type):
  if x.type == new_type:
    return x
  if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type):
    new_type = ir.IntegerType(new_type)
    x_ty = ir.VectorType(x.type)
@ -1200,8 +1223,50 @@ def bitcast(x: ir.Value, new_type: ir.Type):
    x_ty = ir.IntegerType(x.type)
    assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
    return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
  if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
    x_ty = ir.VectorType(x.type)
    new_ty = ir.VectorType(new_type)
    if bitwidth(x_ty) != bitwidth(new_ty):
      raise ValueError(f"Can't bitcast {x.type} to {new_type}")
    return vector.bitcast(new_type, x)
  raise ValueError(f"Can't bitcast {x.type} to {new_type}")


def ceil_div(x: int, y: int):
  return (x + y - 1) // y


def vector_slice(v: ir.Value, s: slice):
  v_ty = ir.VectorType(v.type)
  if len(v_ty.shape) != 1:
    raise NotImplementedError(v_ty)
  [v_len] = v_ty.shape
  slice_length = len(range(v_len)[s])
  return vector.extract_strided_slice(
      ir.VectorType.get((slice_length,), v_ty.element_type),
      v, [s.start or 0], [slice_length], [1],
  )


def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
  index = ir.IndexType.get()
  if not vectors:
    raise ValueError("Cannot concatenate an empty list of vectors")
  vty = vectors[0].type
  if not ir.VectorType.isinstance(vty):
    raise ValueError("Cannot concatenate non-vector values")
  if vty.rank != 1:
    raise NotImplementedError("Only 1D vectors are supported")
  for v in vectors:
    if v.type != vty:
      raise ValueError("Cannot concatenate vectors of different types")
  result = llvm.mlir_undef(
      ir.VectorType.get((vty.shape[0] * len(vectors),), vty.element_type)
  )
  offset = 0
  for v in vectors:
    for i in range(vty.shape[0]):
      elem = vector.extractelement(v, position=c(i, index))
      result = vector.insertelement(elem, result, position=c(offset + i, index))
    offset += vty.shape[0]
  return result
@ -259,6 +259,8 @@ def wgmma(
  The refs must be contiguous or be contiguous except for having their two minor
  dimensions swapped.
  """
  if swizzle == 16:
    raise NotImplementedError("No swizzle is not supported")
  # Step 1. Establish the shape and element type of the operation.
  if not ir.MemRefType.isinstance(b.type):
    raise ValueError(f"B must be a memref, got: {b.type}")
@ -214,6 +214,8 @@ nanobind_extension(
    module_name = "utils",
    deps = [
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/synchronization",
        "@nanobind",
@ -65,6 +65,7 @@ cc_library(
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
        "@llvm-project//mlir:VectorDialect",
    ],
)

@ -238,11 +238,12 @@ NB_MODULE(_mosaic_gpu_ext, m) {
                         "failed to enable tracking of kernel activity by CUPTI");
  });
  m.def("_cupti_get_timings", []() {
    THROW_IF_CUPTI_ERROR(
        cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED),
        "failed to flush CUPTI activity buffers");
    THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
    THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber),
                         "failed to unsubscribe from CUPTI");
    THROW_IF_CUPTI_ERROR(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_NONE),
                         "failed to flush CUPTI activity buffers");
    THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
    return profiler_state.timings;
  });
}
@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include "jaxlib/mosaic/gpu/passes.h"
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
@ -23,6 +24,7 @@ limitations under the License.
#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/SymbolTable.h"
@ -36,6 +38,49 @@ namespace gpu {

namespace {

// Upstream MLIR does not implement an LLVM lowering pattern for this op.
struct ConvertExtractStridedSlicePattern final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  mlir::LogicalResult matchAndRewrite(
      mlir::vector::ExtractStridedSliceOp op, OpAdaptor subst,
      mlir::ConversionPatternRewriter &rewriter) const override {
    auto vty = op.getSourceVectorType();
    if (vty.getRank() != 1) {
      return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported");
    }
    int64_t size =
        (*op.getSizes().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
    if (size < 0) {
      return rewriter.notifyMatchFailure(op, "size is negative");
    }
    int64_t start =
        (*op.getOffsets().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
    int64_t stride =
        (*op.getStrides().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
    if (stride != 1) {
      return rewriter.notifyMatchFailure(op, "only stride 1 is supported");
    }
    if (start < 0 || start + size > vty.getShape()[0]) {
      return rewriter.notifyMatchFailure(op, "slice is out of bounds");
    }
    mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
        op.getLoc(), op.getResult().getType());
    for (int64_t i = 0; i < size; ++i) {
      result = rewriter.create<mlir::LLVM::InsertElementOp>(
          op.getLoc(), result,
          rewriter.create<mlir::LLVM::ExtractElementOp>(
              op.getLoc(), subst.getVector(),
              rewriter.create<mlir::LLVM::ConstantOp>(
                  op.getLoc(), rewriter.getI32IntegerAttr(i + start))),
          rewriter.create<mlir::LLVM::ConstantOp>(
              op.getLoc(), rewriter.getI32IntegerAttr(i)));
    }
    rewriter.replaceOp(op, result);
    return mlir::success();
  }
};

class ConvertGpuToLLVMPass
    : public jaxlib::mlir::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp> {
 public:
@ -58,6 +103,7 @@ class ConvertGpuToLLVMPass
    });
    auto symtab = mlir::SymbolTable(getOperation());
    mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false);
    patterns.insert<ConvertExtractStridedSlicePattern>(&getContext());
    if (mlir::applyPartialConversion(getOperation(), target,
                                     std::move(patterns))
            .failed()) {
@ -16,9 +16,13 @@ limitations under the License.
#include <Python.h>

#include <cstddef>
#include <utility>
#include <vector>

#include "nanobind/nanobind.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/synchronization/mutex.h"

@ -293,6 +297,69 @@ PyMethodDef safe_zip_def = {
    METH_FASTCALL,
};

nb::list TopologicalSort(nb::str parents_attr,
                         nb::iterable end_nodes_iterable) {
  // This is a direct conversion of the original Python implementation.
  // More efficient implementations of a topological sort are possible (and
  // indeed, easier to write), but changing the choice of topological order
  // would break existing tests.
  std::vector<nb::object> end_nodes;
  absl::flat_hash_set<PyObject*> seen;
  for (nb::handle n : end_nodes_iterable) {
    nb::object node = nb::borrow(n);
    if (seen.insert(node.ptr()).second) {
      end_nodes.push_back(node);
    }
  }

  nb::list sorted_nodes;
  if (end_nodes.empty()) {
    return sorted_nodes;
  }

  std::vector<nb::object> stack = end_nodes;
  absl::flat_hash_map<PyObject*, int> child_counts;
  while (!stack.empty()) {
    nb::object node = std::move(stack.back());
    stack.pop_back();
    auto& count = child_counts[node.ptr()];
    if (count == 0) {
      for (nb::handle parent : node.attr(parents_attr)) {
        stack.push_back(nb::borrow(parent));
      }
    }
    ++count;
  }

  for (nb::handle n : end_nodes) {
    child_counts[n.ptr()] -= 1;
  }

  std::vector<nb::object> childless_nodes;
  childless_nodes.reserve(end_nodes.size());
  for (nb::handle n : end_nodes) {
    if (child_counts[n.ptr()] == 0) {
      childless_nodes.push_back(nb::borrow(n));
    }
  }

  while (!childless_nodes.empty()) {
    nb::object node = std::move(childless_nodes.back());
    childless_nodes.pop_back();
    sorted_nodes.append(node);
    for (nb::handle parent : node.attr(parents_attr)) {
      auto& count = child_counts[parent.ptr()];
      if (count == 1) {
        childless_nodes.push_back(nb::borrow(parent));
      } else {
        --count;
      }
    }
  }
  sorted_nodes.reverse();
  return sorted_nodes;
}

}  // namespace

NB_MODULE(utils, m) {
@ -304,6 +371,13 @@ NB_MODULE(utils, m) {
  m.attr("safe_zip") = nb::steal<nb::object>(
      PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));

  m.def("topological_sort", &TopologicalSort, nb::arg("parents_attr"),
        nb::arg("end_nodes"),
        "Computes a topological sort of a graph of objects. parents_attr is "
        "the name of the attribute on each object that contains the list of "
        "parent objects. end_nodes is an iterable of objects from which we "
        "should start a backwards search.");

  // Python has no reader-writer lock in its standard library, so we expose
  // bindings around absl::Mutex.
  nb::class_<absl::Mutex>(m, "Mutex")
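A hedged usage sketch of the new binding (the import path is assumed; see the Python fallback earlier in this diff for the equivalent behavior):

    from jax._src.lib import utils as jaxlib_utils  # assumed import path

    class N:
      def __init__(self, parents=()):
        self.parents = list(parents)

    root = N(); mid = N([root]); leaf = N([mid])
    order = jaxlib_utils.topological_sort("parents", [leaf])
    assert order.index(root) < order.index(mid) < order.index(leaf)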
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for AOT compilation."""

import contextlib
import unittest
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax.api_util."""

import itertools as it
from absl.testing import absltest
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Array."""

import contextlib
import math
@ -1356,6 +1356,32 @@ class VmappableTest(jtu.JaxTestCase):
    self.assertEqual(ans.names, expected.names)
    self.assertAllClose(ans.data, expected.data)

  def test_types_with_same_spec(self):
    # We register NamedArray.
    batching.register_vmappable(NamedArray, NamedMapSpec, int,
                                named_to_elt, named_from_elt, None)

    # We then register another type that uses NamedMapSpec as the spec_type too,
    # and immediately unregister it.
    class Foo:
      pass
    batching.register_vmappable(Foo, NamedMapSpec, int,
                                named_to_elt, named_from_elt, None)
    batching.unregister_vmappable(Foo)

    # We should still be able to use vmap on NamedArray.
    def f(x):
      return named_mul(x, x)

    x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
    ans = jax.jit(f)(x)
    expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2)

    self.assertEqual(ans.names, expected.names)
    self.assertAllClose(ans.data, expected.data)

    # And unregister NamedArray without exceptions.
    batching.unregister_vmappable(NamedArray)

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
@ -37,18 +37,41 @@ def call_kernel(
  m, n = grid
  return jnp.concatenate([
      jnp.concatenate([
          kernel(i, j, *args) for j in range(n)], axis=1)
          kernel((i, j), *args) for j in range(n)], axis=1)
      for i in range(m)], axis=0)


def uniform_kernel(i: int, j: int, total_size, block_size, tile_size):
  """Uniform random sampling kernel function."""
  global_key = jax.random.key(0)
  keys = blocked_sampler.blocked_fold_in(global_key,
def call_kernel_3d(
    kernel,
    grid: tuple[int, int],
    *args
):
  """Calls a kernel over a 3D grid and concatenates results to a single array."""
  depth, rows, cols = grid
  return jnp.concatenate([
      jnp.concatenate([
          jnp.concatenate([
              jnp.array(kernel((i, j, k), *args))
              for k in range(cols)], axis=2)
          for j in range(rows)], axis=1)
      for i in range(depth)], axis=0)


def blocked_fold_in(block_index, key, total_size, block_size, tile_size):
  """Folds in block_index into global_key."""
  return blocked_sampler.blocked_fold_in(key,
                                         total_size=total_size,
                                         block_size=block_size,
                                         tile_size=tile_size,
                                         block_index=(i, j))
                                         block_index=block_index)


def uniform_kernel(block_index, key, total_size, block_size, tile_size):
  """Uniform random sampling kernel function."""
  keys = blocked_fold_in(block_index, key,
                         total_size=total_size,
                         block_size=block_size,
                         tile_size=tile_size)
  return blocked_sampler.sample_block(jax.random.uniform,
                                      keys,
                                      block_size=block_size,
@ -74,17 +97,46 @@ class BlockedSamplerTest(jtu.JaxTestCase):
  )
  def test_block_shape_invariance(self, total_size, block_size_a,
                                  block_size_b, tile_size, transpose_grid):
    global_key = jax.random.key(0)
    grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
    result_a = call_kernel(
        uniform_kernel, grid_a, transpose_grid,
        uniform_kernel, grid_a, transpose_grid, global_key,
        total_size, block_size_a, tile_size)

    grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
    result_b = call_kernel(
        uniform_kernel, grid_b, transpose_grid,
        uniform_kernel, grid_b, transpose_grid, global_key,
        total_size, block_size_b, tile_size)
    np.testing.assert_array_equal(result_a, result_b)


class BlockedFoldInTest(jtu.JaxTestCase):
  @parameterized.named_parameters(
      # Check that sampling a tensor of total size > jnp.iinfo(jnp.uint32).max works
      # as expected. Specifically, blocked key folding does not depend on the total
      # size of the tensor, but only the total number of tiles.
      # Using a 3D grid (with very large inner dimensions) triggers an overflow in a
      # previous implementation of blocked_fold_in.
      dict(testcase_name='4096x512_vs_1024x2048',
           total_size=(2, 64 * 1024, 64 * 1024), block_size_a=(1, 4096, 512),
           block_size_b=(1, 1024, 2048), tile_size=(1, 1024, 512)),
  )
  def test_blocked_fold_in_shape_invariance(self, total_size, block_size_a,
                                            block_size_b, tile_size):
    global_key = jax.random.key(0)
    grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
    result_a = call_kernel_3d(
        blocked_fold_in, grid_a, global_key, total_size,
        block_size_a, tile_size)

    grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
    result_b = call_kernel_3d(
        blocked_fold_in, grid_b, global_key, total_size,
        block_size_b, tile_size)
    np.testing.assert_array_equal(jax.random.key_data(result_a),
                                  jax.random.key_data(result_b))


if __name__ == "__main__":
  absltest.main()
@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for release_backend_clients."""

from absl.testing import absltest
import jax
@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for --debug_nans."""

from absl.testing import absltest

import jax
@ -20,12 +20,14 @@ from jax._src import config
from jax._src import error_check
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P


JaxValueError = error_check.JaxValueError


config.parse_flags_with_absl()
jtu.request_cpu_devices(4)


@jtu.with_config(jax_check_tracer_leaks=True)
@ -190,6 +192,23 @@ class ErrorCheckTests(jtu.JaxTestCase):
    ):
      jax.jit(error_check.raise_if_error)()

  @parameterized.product(jit=[True, False])
  @jtu.with_user_mesh((2, 2), ("x", "y"))
  def test_error_check_explicit_mode(self, mesh, jit):
    def f(x):
      error_check.set_error_if(x <= 0, "x must be greater than 0")
      return x + 1

    if jit:
      f = jax.jit(f)

    sharding = NamedSharding(mesh, P("x", "y"))
    x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding)
    with error_check.error_checking_context():
      f(x)
    with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
      error_check.raise_if_error()


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for garbage allocation guard."""
|
||||
|
||||
import gc
|
||||
import weakref
|
||||
|
@ -12,8 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for jax.numpy.ufunc and its methods."""
|
||||
|
||||
import itertools
|
||||
from functools import partial
|
||||
|
||||
|
@ -3618,6 +3618,15 @@ class LaxTest(jtu.JaxTestCase):
|
||||
x = lax.optimization_barrier((2, 3))
|
||||
self.assertEqual((2, 3), x)
|
||||
|
||||
def test_optimization_barrier_autodiff(self):
|
||||
def f(x):
|
||||
y = 1. * x
|
||||
x, y = lax.optimization_barrier((x, y))
|
||||
z = 2. * x
|
||||
return y + z
|
||||
g = jax.grad(f)(5.) # doesn't crash
|
||||
self.assertAllClose(g, 3., check_dtypes=False)
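
The expected value follows because optimization_barrier acts as an identity on both primal values and gradients: f(x) = y + z with y = 1·x and z = 2·x, so df/dx = 1 + 2 = 3, which is what the test asserts.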


class LazyConstantTest(jtu.JaxTestCase):
  def _Check(self, make_const, expected):

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the LAPAX linear algebra module."""

from functools import partial
import itertools
from typing import Iterator

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mesh utils."""

import collections
from collections.abc import Sequence

@ -74,6 +74,37 @@ class LayoutInferenceTest(parameterized.TestCase):
    self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
    self.assertSequenceEqual(add.attributes["out_layouts"], [layout])

  def test_infer_strided_layout_from_shape_cast(self):
    shape = (16, 8)
    elt_type = ir.BF16Type.get()
    src_type = ir.VectorType.get(shape, elt_type)
    dst_type = ir.VectorType.get([*reversed(shape)], elt_type)
    op = None

    def body(x):
      nonlocal op
      op = vector.ShapeCastOp(dst_type, x)

    with ir.InsertionPoint(self.module.body):
      func.FuncOp.from_py_func(src_type)(body)

    mgpu.infer_layout(self.module)

    in_layout = layouts.to_layout_attr(
        mgpu.WGStridedFragLayout.from_shaped_type(src_type)
    )
    out_layout = layouts.to_layout_attr(
        mgpu.WGStridedFragLayout.from_shaped_type(dst_type)
    )

    self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout])
    self.assertSequenceEqual(op.attributes["out_layouts"], [out_layout])

    # Ensure that we can recover the original layout.
    del op.attributes["in_layouts"]
    mgpu.infer_layout(self.module)
    self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout])

  def test_infer_splat_layout_for_splat_constants(self):
    shape = (16, 8)
    elt_type = ir.BF16Type.get()

@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Mosaic GPU DSL functions and utilities."""

from collections.abc import Sequence
import contextlib
import dataclasses
import enum
import itertools
@ -84,6 +84,20 @@ def mlir_sum(elems):
  return total


@contextlib.contextmanager
def get_sass():
  prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
  os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
  try:
    with jtu.capture_stdout() as output:
      yield output
  finally:
    if prev_dump is not None:
      os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
    else:
      del os.environ["MOSAIC_GPU_DUMP_SASS"]
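
The context manager above enables the SASS dump for the duration of the block by setting MOSAIC_GPU_DUMP_SASS, captures whatever the compiler prints to stdout, and restores the previous environment on exit. The later hunks use it like this (the kernel invocation is illustrative):

with get_sass() as sass:
  result = f(x)  # any mgpu.as_gpu_kernel invocation compiled in the block
# sass() returns the captured dump; e.g. count shuffle instructions:
num_shuffles = sass().count("SHFL.BFLY")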


def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
  index = ir.IndexType.get()
  thread_id = gpu.thread_id(gpu.Dimension.x)
@ -519,14 +533,38 @@ class WGMMALayoutTest(TestCase):
    )()
    np.testing.assert_array_equal(iota, expected)

  @parameterized.named_parameters(
      ("bf16_i8", jnp.bfloat16, jnp.int8),
      ("i8_bf16", jnp.int8, jnp.bfloat16),
      ("i8_i8", jnp.int8, jnp.int8),
      ("i4_i4", jnp.int4, jnp.int4),
      ("i4_bf16", jnp.int4, jnp.bfloat16),
  @parameterized.parameters(jnp.int8, jnp.int16, jnp.int32)
  def test_sub_byte_conversion(self, jax_dtype_to):
    jax_dtype_from = jnp.int4
    def kernel(ctx, inp, out, smem):
      del ctx  # Unused.
      smem_inp, smem_out = smem
      copy(inp, smem_inp, swizzle=16)
      t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16)
      t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True)
      t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
      copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)

    x = self.prng.integers(
        low=-8, high=7, size=(1, 1, 64, 64), dtype=np.int32
    ).astype(jax_dtype_from)
    y = x.astype(jax_dtype_to)
    f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y))
    np.testing.assert_array_equal(f(x), y)
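
For the three target dtypes, the store swizzle of 32 * jnp.dtype(jax_dtype_to).itemsize works out to 32, 64, and 128 bytes for int8, int16, and int32 respectively, keeping the swizzle proportional to the element width after the upcast.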

  @parameterized.product(
      jax_dtype_from_to=(
          (jnp.int8, jnp.bfloat16),
          (jnp.int4, jnp.bfloat16),
      ),
      layout=(
          fa.WGMMA_LAYOUT,
          fa.WGMMA_LAYOUT_UPCAST_2X,
          fa.WGMMA_LAYOUT_UPCAST_4X,
      ),
  )
  def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
  def test_optimized_conversion(self, jax_dtype_from_to, layout):
    jax_dtype_from, jax_dtype_to = jax_dtype_from_to
    mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
    mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
    m = 128
@ -539,7 +577,7 @@ class WGMMALayoutTest(TestCase):
        smem_from,
        swizzle=128,
        is_signed=utils.is_signed(jax_dtype_from),
        layout=fa._tiled_wgmma_layout((m, n))
        layout=layout,
    )
    t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
    t.store_tiled(smem_to, swizzle=128)
@ -2175,19 +2213,11 @@ class LayoutTest(TestCase):
        .transpose(0, 2, 1, 3)
    )

    prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
    os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
    try:
      with jtu.capture_stdout() as get_sass:
        iota = mgpu.as_gpu_kernel(
            kernel, (1, 1, 1), (128, 1, 1), expected, expected,
            [expected, expected, mgpu.TMABarrier()],
        )(expected)
    finally:
      if prev_dump is not None:
        os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
      else:
        del os.environ["MOSAIC_GPU_DUMP_SASS"]
    with get_sass() as sass:
      iota = mgpu.as_gpu_kernel(
          kernel, (1, 1, 1), (128, 1, 1), expected, expected,
          [expected, expected, mgpu.TMABarrier()],
      )(expected)
    np.testing.assert_array_equal(iota, expected)

    # Verify that we don't use too many registers for the transfers.
@ -2200,7 +2230,7 @@ class LayoutTest(TestCase):
      expected_regs //= 2
    for instr in ("STS", "LDS"):
      with self.subTest(instr + " count"):
        addrs = re.findall(instr + r".* \[(.*)\]", get_sass())
        addrs = re.findall(instr + r".* \[(.*)\]", sass())
        def get_reg(addr):
          if (pos := addr.find("+")) != -1:
            return addr[:pos]
@ -2214,13 +2244,13 @@ class LayoutTest(TestCase):
    col_tiling = swizzle // bytewidth(utils.dtype_to_ir_type(dtype))
    m, n = 128, col_tiling * 2
    tiling = (64, col_tiling)
    tiled_layout = fa._tiled_wgmma_layout_for_upcast((m, n))
    layout = fa.WGMMA_LAYOUT_UPCAST_2X
    def kernel(ctx, in_, out, smems):
      smem_in, smem_out, barrier = smems
      ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
      barrier.wait()
      t = mgpu.FragmentedArray.load_tiled(
          smem_in, swizzle=swizzle, is_signed=True, layout=tiled_layout
          smem_in, swizzle=swizzle, is_signed=True, layout=layout
      )
      t.store_tiled(smem_out, swizzle=swizzle)
      mgpu.commit_shared()
@ -2275,6 +2305,61 @@ class LayoutTest(TestCase):
    )(x)
    np.testing.assert_array_equal(y, y_ref)

  @parameterized.parameters(
      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int8, 1),
      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int16, 1),
      (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X, jnp.int4, jnp.int4, 1),
      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5),
      (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2),
  )
  def test_upcast_to_wgmma(
      self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg
  ):
    in_dtype = jnp.dtype(in_dtype)
    out_dtype = jnp.dtype(jnp.int16)
    out_dtype_mlir = utils.dtype_to_ir_type(out_dtype)
    swizzle = 128
    in_col_tiling = 8 * swizzle // jnp.iinfo(in_dtype).bits
    in_tiling = (8, in_col_tiling)
    out_col_tiling = swizzle // out_dtype.itemsize
    out_tiling = (8, out_col_tiling)
    m, n = 128, in_col_tiling * 2
    regs_per_thread = None
    def kernel(ctx, in_, out, smems):
      nonlocal regs_per_thread
      smem_in, smem_out, barrier = smems
      ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
      barrier.wait()
      t = mgpu.FragmentedArray.load_tiled(
          smem_in, swizzle=swizzle, is_signed=True, layout=start_layout
      )
      regs_per_thread = t.registers.size
      t = t.astype(utils.dtype_to_ir_type(cast_dtype), is_signed=True)
      t = t.to_layout(end_layout)
      t = t.astype(out_dtype_mlir, is_signed=True)
      t.store_tiled(smem_out, swizzle=swizzle)
      mgpu.commit_shared()
      ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
      ctx.await_async_copy(0)
    def tile(x, tiling):
      return x.reshape(
          x.shape[0] // tiling[0], tiling[0], x.shape[1] // tiling[1], tiling[1]
      ).transpose(0, 2, 1, 3)
    in_iinfo = jnp.iinfo(in_dtype)
    x = jax.random.randint(
        jax.random.key(42), (m, n), in_iinfo.min, in_iinfo.max, dtype=jnp.int32
    ).astype(in_dtype)
    xt = tile(x, in_tiling)
    y = x.astype(out_dtype)
    yt = tile(y, out_tiling)
    f = mgpu.as_gpu_kernel(
        kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()],
    )
    with get_sass() as sass:
      yt_kernel = f(xt)
    np.testing.assert_array_equal(yt_kernel, yt)
    self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)
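
The tile helper used above is the standard tile-major reshuffle: an (m, n) array becomes (m/tm, n/tn, tm, tn), so each [i, j] slice is one contiguous tile. A plain NumPy check of that property:

import numpy as np

x = np.arange(4 * 6).reshape(4, 6)
tm, tn = 2, 3
xt = x.reshape(4 // tm, tm, 6 // tn, tn).transpose(0, 2, 1, 3)
assert xt.shape == (2, 2, tm, tn)
# Tile (i, j) is the block x[i*tm:(i+1)*tm, j*tn:(j+1)*tn]:
assert (xt[1, 0] == x[2:4, 0:3]).all()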


@dataclasses.dataclass(frozen=True)
class Tile:

@ -25,8 +25,11 @@ from jax._src.interpreters import mlir as mlir_interpreter
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import vector
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic.gpu import fragmented_array as fa
from jax.experimental.mosaic.gpu import inference_utils
from jax.experimental.mosaic.gpu import layouts as layouts_lib
import numpy as np


@ -162,6 +165,259 @@ class TransformInferenceTest(parameterized.TestCase):
    )
    self.assertEmpty(inference_utils.out_transforms(async_store_op))

  def test_infer_transforms_for_vector_load_op_derives_from_destination(self):
    vector_load_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref):
      nonlocal vector_load_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_load_op = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      func.FuncOp.from_py_func(smem_ty)(body)

    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )

    mgpu.infer_transforms(self.module)

    expected_transforms = ir.ArrayAttr.get([
        mgpu.dialect.TileTransformAttr.get((8, 64)),
        mgpu.dialect.SwizzleTransformAttr.get(128),
    ])

    self.assertSequenceEqual(
        inference_utils.in_transforms(vector_load_op), [expected_transforms]
    )
    self.assertEmpty(inference_utils.out_transforms(vector_load_op))
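
The expected attributes here are arithmetic consequences of the WGMMA layout on a 64x64 bf16 ref: a 128-byte swizzle spans 64 bf16 elements (2 bytes each), which is what pins the inferred tiling to (8, 64) alongside the SwizzleTransformAttr of 128.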

  def test_infer_transforms_for_vector_load_op_derives_from_source(self):
    vector_load_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref):
      nonlocal vector_load_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_load_op = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      f = func.FuncOp.from_py_func(smem_ty)(body).func_op

    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
    )
    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])

    mgpu.infer_transforms(self.module)

    self.assertSequenceEqual(
        inference_utils.in_transforms(vector_load_op), [transforms]
    )
    self.assertEmpty(inference_utils.out_transforms(vector_load_op))

  def test_infer_transforms_for_vector_load_op_raises_on_mismatches(self):
    vector_load_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref):
      nonlocal vector_load_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_load_op = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      f = func.FuncOp.from_py_func(smem_ty)(body).func_op

    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )
    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])

    with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
      mgpu.infer_transforms(self.module)

  def test_infer_transforms_for_vector_store_op_derives_from_destination(self):
    vector_store_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref, value_to_store):
      nonlocal vector_store_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_store_op = vector.StoreOp(
          value_to_store, smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      value_ty = ir.VectorType.get(shape, elt_ty)
      func.FuncOp.from_py_func(smem_ty, value_ty)(body)

    vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )

    mgpu.infer_transforms(self.module)

    expected_transforms = ir.ArrayAttr.get([
        mgpu.dialect.TileTransformAttr.get((8, 64)),
        mgpu.dialect.SwizzleTransformAttr.get(128),
    ])

    self.assertSequenceEqual(
        inference_utils.in_transforms(vector_store_op), [expected_transforms]
    )
    self.assertEmpty(inference_utils.out_transforms(vector_store_op))

  def test_infer_transforms_for_vector_store_op_derives_from_source(self):
    vector_store_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref, value_to_store):
      nonlocal vector_store_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_store_op = vector.StoreOp(
          value_to_store, smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      value_ty = ir.VectorType.get(shape, elt_ty)
      f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op

    vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
    )
    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])

    mgpu.infer_transforms(self.module)

    self.assertSequenceEqual(
        inference_utils.in_transforms(vector_store_op), [transforms]
    )
    self.assertEmpty(inference_utils.out_transforms(vector_store_op))

  def test_infer_transforms_for_vector_store_op_raises_on_mismatches(self):
    vector_store_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()

    def body(smem_ref, value_to_store):
      nonlocal vector_store_op
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      vector_store_op = vector.StoreOp(
          value_to_store, smem_ref, [zero] * len(shape)
      )

    with ir.InsertionPoint(self.module.body):
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
      value_ty = ir.VectorType.get(shape, elt_ty)
      f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op

    vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )
    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])

    with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
      mgpu.infer_transforms(self.module)

  def test_infer_transforms_for_slice_smem_op_derives_from_user(self):
    slice_smem_op = vector_load_op = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()
    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")

    def body(offset):
      nonlocal slice_smem_op, vector_load_op
      slice_smem_op = mgpu.dialect.SliceSMEMOp(
          ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset
      )
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      load_offsets = [zero] * len(shape)
      vector_load_op = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets
      )

    with ir.InsertionPoint(self.module.body):
      func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body)

    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )

    mgpu.infer_transforms(self.module)

    expected_transforms = ir.ArrayAttr.get([
        mgpu.dialect.TileTransformAttr.get((8, 64)),
        mgpu.dialect.SwizzleTransformAttr.get(128),
    ])

    self.assertEmpty(inference_utils.in_transforms(slice_smem_op))
    self.assertSequenceEqual(
        inference_utils.out_transforms(slice_smem_op), [expected_transforms]
    )

  def test_infer_transforms_for_slice_smem_op_raises_on_mismatches(self):
    slice_smem_op = vector_load_op1 = vector_load_op2 = None
    shape = (64, 64)
    elt_ty = ir.BF16Type.get()
    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")

    def body(offset):
      nonlocal slice_smem_op, vector_load_op1, vector_load_op2
      slice_smem_op = mgpu.dialect.SliceSMEMOp(
          ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset
      )
      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
      load_offsets = [zero] * len(shape)
      vector_load_op1 = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets
      )
      vector_load_op2 = vector.LoadOp(
          ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets
      )

    with ir.InsertionPoint(self.module.body):
      func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body)

    vector_load_op1.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
    )
    vector_load_op2.attributes["out_layouts"] = ir.ArrayAttr.get(
        [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
    )
    vector_load_op2.attributes["in_transforms"] = ir.ArrayAttr.get(
        [ir.ArrayAttr.get([mgpu.dialect.TransposeTransformAttr.get((1, 0))])]
    )

    with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
      mgpu.infer_transforms(self.module)


if __name__ == "__main__":
  parameterized.absltest.main(testLoader=jtu.JaxTestLoader())

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Mosaic GPU CUPTI-based profiler."""

from absl.testing import absltest, parameterized
import jax

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for nn module."""

import collections
from functools import partial
import itertools

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the optimizers module."""

import functools

from absl.testing import absltest

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for pull block spec."""

from absl.testing import absltest
from absl.testing import parameterized
import jax

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Pallas indexing logic and abstractions."""

from __future__ import annotations
import sys
import unittest

@ -185,7 +185,7 @@ class PallasCallTest(PallasTest):
    np.testing.assert_array_equal(kernel(x, y), x + y[0])

  @parameterized.product(
      shape=[(128,)], thread_semantics=[*plgpu.ThreadSemantics]
      shape=[(128,), (128, 128)], thread_semantics=[*plgpu.ThreadSemantics]
  )
  def test_reduce_sum(self, shape, thread_semantics):
    @functools.partial(

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for common JAX operations within pallas_call."""

from collections.abc import Sequence
import functools
import itertools

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Pallas error handling."""
import functools
import traceback


@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TPU specific operations within pallas_call."""

import functools
import math

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for distributed pallas TPU operations."""

import functools
import os
import tempfile

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for random ops in Pallas + Mosaic."""

from absl.testing import absltest
from absl.testing import parameterized

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Pallas mesh API."""
import functools
from absl.testing import absltest
import jax

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for splash_attention."""
from __future__ import annotations

from collections.abc import Callable

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for splash_attention_masks."""
from __future__ import annotations

from absl.testing import absltest

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for interoperability between JAX and pickling libraries."""

import pickle
import unittest

@ -6138,6 +6138,19 @@ class ShardingInTypesTest(jtu.JaxTestCase):
    self.assertDictEqual(out.sharding.mesh._axis_types_dict,
                         {AxisType.Auto: ('x',)})

  @jtu.with_user_mesh((2,), 'x')
  def test_device_put_use_mesh(self, mesh):
    out = jax.device_put(np.arange(8), P('x'))
    self.assertArraysEqual(out, np.arange(8))
    self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))

  def test_device_put_no_use_mesh_error(self):
    with self.assertRaisesRegex(
        ValueError,
        'Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is'
        ' passed to device_put'):
      jax.device_put(np.arange(8), P('x'))
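
The error message spells out the fix: wrap the call in a mesh context so the bare PartitionSpec can be resolved. A minimal sketch, assuming at least two devices (jax.sharding.use_mesh is the API the message itself names):

import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P, use_mesh

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
with use_mesh(mesh):
  out = jax.device_put(np.arange(8), P('x'))  # legal inside the context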

  @jtu.with_user_mesh((2,), 'x')
  def test_inputs_different_context(self, mesh):
    np_inp = np.arange(16).reshape(8, 2)

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License

"""Tests for the library of QDWH-based polar decomposition."""
import functools

from absl.testing import absltest

@ -365,6 +365,38 @@ class LaxRandomTest(jtu.JaxTestCase):
    pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
    self._CheckChiSquared(samples, pmf=pmf)

  @jtu.sample_product(
      logits_shape=[(7,), (8, 9), (10, 11, 12)],
      prefix_shape=[(2,), (3, 4), (5, 6)],
  )
  def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape):
    key = random.key(0)

    key, subkey = random.split(key)
    logits = random.normal(subkey, logits_shape)

    key, subkey = random.split(key)
    axis = random.randint(subkey, (), -len(logits_shape), len(logits_shape))

    dists_shape = tuple(np.delete(logits_shape, axis))
    n_categories = logits_shape[axis]
    shape = prefix_shape + dists_shape
    prefix_size = math.prod(prefix_shape)

    if n_categories < prefix_size:
      with self.assertRaisesRegex(ValueError, "Number of samples without replacement"):
        random.categorical(key, logits, axis=axis, shape=shape, replace=False)

    else:
      output = random.categorical(key, logits, axis=axis, shape=shape, replace=False)
      self.assertEqual(output.shape, shape)
      assert (0 <= output).all()
      assert (output < n_categories).all()
      flat = output.reshape((prefix_size, math.prod(dists_shape)))
      counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat)
      assert (counts <= 1).all()
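
`replace=False` is what forces every per-distribution count to be at most one. A standard way to sample a categorical without replacement, and a handy host-side cross-check, is the Gumbel top-k trick; whether this matches JAX's internal implementation is not something the diff shows:

import jax
import jax.numpy as jnp

def categorical_no_replace(key, logits, k):
  # Perturb logits with i.i.d. Gumbel noise; the top-k indices are k
  # distinct categories drawn with the categorical probabilities.
  g = jax.random.gumbel(key, logits.shape)
  return jax.lax.top_k(logits + g, k)[1]

idx = categorical_no_replace(jax.random.key(0), jnp.log(jnp.array([.2, .3, .5])), 2)
assert len(set(idx.tolist())) == 2  # no repeats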


  def testBernoulliShape(self):
    key = self.make_key(0)
    with jax.numpy_rank_promotion('allow'):

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the shape-polymorphic export."""

from __future__ import annotations


@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


"""Tests for stack."""

from absl.testing import absltest

import jax

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Stax library."""

from absl.testing import absltest

import numpy as np

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License

"""Tests for the library of QDWH-based singular value decomposition."""
import functools

import jax

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transfer guards."""

import contextlib
import pickle

@ -201,5 +201,49 @@ class SafeZipTest(jtu.JaxTestCase):
      util.safe_zip((), range(3))


class Node:
  def __init__(self, parents):
    self.parents = parents


class TopologicalSortTest(jtu.JaxTestCase):

  def _check_topological_sort(self, nodes, order):
    self.assertEqual(sorted(nodes, key=id), sorted(order, key=id))
    visited = set()
    for node in order:
      self.assertTrue(all(id(parent) in visited for parent in node.parents))
      visited.add(id(node))

  def test_basic(self):
    a = Node([])
    b = Node([a])
    c = Node([a])
    d = Node([a, c])
    e = Node([b, c])
    out = util.toposort([a, d, e])
    self._check_topological_sort([a, b, c, d, e], out)

  def test_stick(self):
    a = Node([])
    b = Node([a])
    c = Node([b])
    d = Node([c])
    e = Node([d])
    out = util.toposort([e])
    self._check_topological_sort([a, b, c, d, e], out)

  def test_diamonds(self):
    a = Node([])
    b = Node([a])
    c = Node([a])
    d = Node([b, c])
    e = Node([d])
    f = Node([d])
    g = Node([e, f])
    out = util.toposort([g])
    self._check_topological_sort([a, b, c, d, e, f, g], out)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())

4
third_party/xla/workspace.bzl
vendored
@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.

XLA_COMMIT = "4c4aa96f9ffec4bb963b50c50192aeab4da9dc4a"
XLA_SHA256 = "c373e52b2f8b4175c69e99e636ad64b3bcf33fb44d1b7ad6ef8f4162c9052af8"
XLA_COMMIT = "3bb765472122548cc227b8bd2990f00bd533f438"
XLA_SHA256 = "72126aac7602153aee985ca20f73d11c39e3ba9cfb8027492951e787559d0497"

def repo():
    tf_http_archive(