mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
add unit test and move to _src/cudnn dir
This commit is contained in: parent d1141b4058, commit 9b8a100039
@@ -32,6 +32,7 @@ from jax.sharding import Mesh, PartitionSpec, NamedSharding

from jax._src.interpreters import batching
from jax._src import dispatch
from jax._src.lib import cuda_versions

Array = jnp.ndarray
DType = jnp.dtype
@@ -120,6 +121,50 @@ def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
  ]
  return _custom_name_maps[index]

def check_qkv_layout(query, key, value):
  assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \
    "query, key and value should have rank 4."

  # Only support fp16 and bf16 here
  query_dtype = query.dtype
  key_dtype = key.dtype
  value_dtype = value.dtype
  assert query_dtype == key_dtype == value_dtype and query_dtype in [jnp.float16, jnp.bfloat16], \
    "query, key and value should have the same dtype and it should be float16 or bfloat16."

  q_batch, q_seq_len, q_num_heads, q_head_dim = query.shape
  k_batch, k_seq_len, k_num_heads, k_head_dim = key.shape
  v_batch, v_seq_len, v_num_heads, v_head_dim = value.shape
  assert (q_batch == k_batch == v_batch) \
    and (k_seq_len == v_seq_len) \
    and (q_num_heads == k_num_heads == v_num_heads) \
    and (q_head_dim == k_head_dim == v_head_dim), \
    "query should have layout [batch, q_seq, num_heads, head_dim], " \
    "key and value should have layout [batch, kv_seq, num_heads, head_dim]."

def check_is_flash_attention(query, key):
  batch, q_seq_len, num_heads, head_dim = query.shape
  _, kv_seq_len, _, _ = key.shape
  # check if the attention pattern is supported by flash attention or regular fused attention
  if q_seq_len > 512 and q_seq_len == kv_seq_len and head_dim in [64, 128]:
    # flash attention is supported
    is_flash_attention = True
  elif q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64:
    # regular fused attention is supported
    is_flash_attention = False
  else:
    raise NotImplementedError("Unsupported sequence length and head dim.")
  return is_flash_attention

def check_cuDNN_version(is_flash_attention):
  # check that cuDNN is installed and that the cuDNN version constraint is satisfied
  if cuda_versions is None:
    raise RuntimeError("cuDNN is not detected.")
  elif is_flash_attention and cuda_versions.cudnn_get_version() < 8903:
    raise RuntimeError("Requires cuDNN at least 8.9.3 to run flash attention.")
  elif not is_flash_attention and cuda_versions.cudnn_get_version() < 8901:
    raise RuntimeError("Requires cuDNN at least 8.9.1 to run fused attention.")

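# Illustrative sketch (not part of this diff): which attention path the new dispatch
# logic selects for a few assumed input shapes, using the helpers above.
#
#   q = k = jnp.zeros((2, 1024, 8, 64), dtype=jnp.float16)   # seq_len 1024, head_dim 64
#   check_is_flash_attention(q, k)   # -> True: flash attention path
#   q = k = jnp.zeros((2, 256, 8, 64), dtype=jnp.float16)    # seq_len 256, head_dim 64
#   check_is_flash_attention(q, k)   # -> False: regular fused attention path
#   q = k = jnp.zeros((2, 256, 8, 128), dtype=jnp.float16)   # seq_len 256, head_dim 128
#   check_is_flash_attention(q, k)   # raises NotImplementedError
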
def _dot_product_attention_fwd(query, key, value, bias, mask,
                               scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  output, _ = _dot_product_attention_fwd_p_wrapper.bind(
@@ -170,15 +215,6 @@ def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, f
def _dot_product_attention_fwd_abstract(query, key, value, bias, mask,
                                        *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  query_dtype = dtypes.canonicalize_dtype(query.dtype)
  key_dtype = dtypes.canonicalize_dtype(key.dtype)
  value_dtype = dtypes.canonicalize_dtype(value.dtype)
  # Q, K and V must have the same data type
  assert query_dtype == key_dtype == value_dtype
  # Only support fp16 and bf16 here
  assert query_dtype in [jnp.float16, jnp.bfloat16]
  # Q, K and V must be 4-D tensors
  assert len(query.shape) == len(key.shape) == len(value.shape) == 4

  batch, q_seq_len, num_heads, head_dim = query.shape
  _, kv_seq_len, _, _ = key.shape
  output_shape = (batch, q_seq_len, num_heads, head_dim)
@@ -201,12 +237,7 @@ def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activatio
  query_dtype = dtypes.canonicalize_dtype(query.dtype)
  key_dtype = dtypes.canonicalize_dtype(key.dtype)
  value_dtype = dtypes.canonicalize_dtype(value.dtype)
  # Q, K and V must have the same data type
  assert query_dtype == key_dtype == value_dtype
  # Only support fp16 and bf16 here
  assert query_dtype in [jnp.float16, jnp.bfloat16]
  # Q, K and V must be 4-D tensors
  assert len(query.shape) == len(key.shape) == len(value.shape) == 4

  return (
    ShapedArray(
      query.shape, query_dtype
@@ -587,7 +618,7 @@ def _dot_product_attention(query: Array,

# _dot_product_attention_fwd must have the same func signature as _dot_product_attention
_dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_attention_bwd_rule)


# User interface
def dot_product_attention(query: Array,
                          key: Array,
@@ -618,10 +649,13 @@ def dot_product_attention(query: Array,
  Returns:
    Output of shape `[batch, length, num_heads, v_depth_per_head]`.
  """
  batch, q_seq_len, num_heads, head_dim = query.shape
  is_flash_attention = False
  if q_seq_len > 512:
    is_flash_attention = True
  # check if query, key and value layout meets cuDNN layout requirement
  check_qkv_layout(query, key, value)
  # check if flash attention is supported for this attention pattern
  is_flash_attention = check_is_flash_attention(query, key)
  # check if cuDNN is installed and if cuDNN version is sufficient
  check_cuDNN_version(is_flash_attention)

  variadic_args = (bias is not None, mask is not None)
  if bias is None:
    bias = jnp.zeros(0, dtype=query.dtype)
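For reference, a minimal usage sketch of the public entry point (an assumed example, not part of the commit; it presumes a CUDA device with a recent cuDNN and fp16 inputs of shape [batch, seq_len, num_heads, head_dim]):

from functools import partial
import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stableHLO import dot_product_attention

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (4, 1024, 8, 64)  # [batch, seq_len, num_heads, head_dim]
query = jax.random.normal(k1, shape, dtype=jnp.float16)
key = jax.random.normal(k2, shape, dtype=jnp.float16)
value = jax.random.normal(k3, shape, dtype=jnp.float16)
# bias/mask are optional; dropout disabled here for determinism
attend = jax.jit(partial(dot_product_attention, scale=1.0, bias=None, mask=None,
                         is_causal_mask=False, dropout_rate=0.0))
out = attend(query, key, value)  # output shape [4, 1024, 8, 64]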
jax/_src/cudnn/fused_attention_stableHLO_test.py (new file, 209 lines)
@@ -0,0 +1,209 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from absl.testing import absltest
from typing import Any, Optional
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.pjit import pjit
from jax._src import config
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stableHLO import dot_product_attention

config.parse_flags_with_absl()
Array = jnp.ndarray

def f(query: Array,
      key: Array,
      value: Array,
      bias: Optional[Array] = None,
      mask: Optional[Array] = None,
      causal_mask: bool = False,
      scale: float = 0.5,
      dropout_rate: float = 0.1) -> Array:

  output = dot_product_attention(
      query,
      key,
      value,
      scale=scale,
      bias=bias,
      mask=mask,
      is_causal_mask=causal_mask,
      dropout_rate=dropout_rate)
  return output

def f_train(query: Array,
            key: Array,
            value: Array,
            grad: Array,
            bias: Optional[Array] = None,
            mask: Optional[Array] = None,
            causal_mask: bool = False,
            scale: float = 0.5,
            dropout_rate: float = 0.1) -> Array:

  out, f_vjp = jax.vjp(
      partial(f, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate),
      query, key, value, bias, None)
  query_grad, key_grad, value_grad, _, _ = f_vjp(grad)
  return out, (query_grad, key_grad, value_grad)

def g(query: Array,
      key: Array,
      value: Array,
      bias: Optional[Array] = None,
      mask: Optional[Array] = None,
      causal_mask: bool = False,
      scale: float = 0.5,
      dropout_rate: float = 0.1) -> Array:

  def get_large_negative_number(dtype):
    if jnp.issubdtype(dtype, jnp.inexact):
      dtype_max = jnp.finfo(dtype).max
    elif jnp.issubdtype(dtype, jnp.integer):
      dtype_max = jnp.iinfo(dtype).max
    else:
      raise ValueError('Unsupported dtype for inputs.')
    return jnp.asarray(-0.7 * dtype_max, dtype=dtype)

  def get_causal_mask(input_t):
    large_negative_number = get_large_negative_number(input_t.dtype)
    t = input_t.shape[2]
    col_idx = jnp.tile(jnp.arange(t)[jnp.newaxis, :], [t, 1])
    row_idx = jnp.tile(jnp.arange(t)[:, jnp.newaxis], [1, t])
    mask = (row_idx < col_idx).astype(input_t.dtype) * large_negative_number
    return mask[jnp.newaxis, jnp.newaxis, :, :]

  if scale != 1.0:
    query = query * scale
  attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
  if causal_mask:
    bias = get_causal_mask(attn_weights)
  if bias is not None:
    attn_weights = attn_weights + bias.astype(attn_weights.dtype)
  attn_weights = jax.nn.softmax(attn_weights)
  if dropout_rate > 0.:
    keep_prob = 1.0 - dropout_rate
    dropout_shape = list(attn_weights.shape)
    dropout_shape[-2] = 1
    dropout_rng = jax.random.PRNGKey(0)
    keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    keep = jnp.broadcast_to(keep, attn_weights.shape)
    multiplier = (
        keep.astype(attn_weights.dtype) /
        jnp.asarray(keep_prob, dtype=attn_weights.dtype))
    attn_weights = attn_weights * multiplier

  return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)

def g_train(query: Array,
            key: Array,
            value: Array,
            grad: Array,
            bias: Optional[Array] = None,
            mask: Optional[Array] = None,
            causal_mask: bool = False,
            scale: float = 0.5,
            dropout_rate: float = 0.1) -> Array:
  out_ref, g_vjp = jax.vjp(
      partial(g, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate),
      query, key, value, bias, None)
  query_grad_ref, key_grad_ref, value_grad_ref, _, _ = g_vjp(grad)
  return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref)

@jtu.with_config(jax_legacy_prng_key='allow')
class DotProductAttentionTest(jtu.JaxTestCase):
  @jtu.sample_product(
      batch_size=[4],
      seq_len=[256, 1024],
      num_heads=[8],
      head_dim=[64, 128],
      use_bias=[True],
      is_causal_mask=[False],
      dropout_rate=[0],
      scale=[0.5],
      dtype=[jnp.float16, jnp.bfloat16]
  )
  @jtu.run_on_devices("cuda")
  def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int,
                head_dim: int, use_bias: bool, is_causal_mask: bool,
                dropout_rate: float, scale: float, dtype: jnp.dtype):
    if seq_len == 256 and is_causal_mask:
      self.skipTest("Fused attention does not support mask generation.")
    if seq_len == 256 and head_dim == 128:
      self.skipTest("Fused attention does not support head dim = 128.")
    if len(jax.local_devices()) < 4:
      self.skipTest("Requires at least 4 devices to run sharding tests.")
    os.environ['XLA_FLAGS'] = '--xla_dump_hlo_as_text --xla_dump_to=./scratch/hlo --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'

    k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
    query = jax.random.normal(
        k1, (batch_size, seq_len, num_heads, head_dim), dtype=dtype)
    key = jax.random.normal(
        k2, (batch_size, seq_len, num_heads, head_dim), dtype=dtype)
    value = jax.random.normal(
        k3, (batch_size, seq_len, num_heads, head_dim), dtype=dtype)
    grad = jax.random.normal(
        k4, (batch_size, seq_len, num_heads, head_dim), dtype=dtype)
    if use_bias:
      bias = jax.random.normal(
          k5, (batch_size, num_heads, seq_len, seq_len), dtype=dtype)
    else:
      bias = None

    jitted_f_train = jax.jit(partial(f_train, causal_mask=is_causal_mask, scale=scale, dropout_rate=dropout_rate))
    jitted_g_train = jax.jit(partial(g_train, causal_mask=is_causal_mask, scale=scale, dropout_rate=dropout_rate))
    devices = np.array(jax.local_devices()[:4])
    devices = devices.reshape((2, 2))
    with Mesh(devices, ('dp', 'tp')) as mesh:
      qkv_spec = PartitionSpec('dp', None, 'tp', None)
      if bias is not None:
        bias_spec = PartitionSpec('dp', 'tp', None, None)
      else:
        bias_spec = None
      query = jax.device_put(query, NamedSharding(mesh, qkv_spec))
      key = jax.device_put(key, NamedSharding(mesh, qkv_spec))
      value = jax.device_put(value, NamedSharding(mesh, qkv_spec))
      if bias is not None:
        bias = jax.device_put(bias, NamedSharding(mesh, bias_spec))
      grad = jax.device_put(grad, NamedSharding(mesh, qkv_spec))
      in_shardings = (qkv_spec, qkv_spec, qkv_spec, qkv_spec, bias_spec, None)
      out_shardings = (None, (qkv_spec, qkv_spec, qkv_spec))
      pjitted_f_train = pjit(jitted_f_train,
                             in_shardings=in_shardings,
                             out_shardings=out_shardings)

      pjitted_g_train = pjit(jitted_g_train,
                             in_shardings=in_shardings,
                             out_shardings=out_shardings)

      out, (query_grad, key_grad, value_grad) = pjitted_f_train(query, key, value, grad, bias, None)
      out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = pjitted_g_train(query, key, value, grad, bias, None)
      assert jnp.allclose(out_ref, out, rtol=1e-5, atol=1e-5)
      if seq_len > 512:
        # query_grad in flash attention is not deterministic
        assert jnp.allclose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2)
      else:
        assert jnp.allclose(query_grad_ref, query_grad, rtol=1e-5, atol=1e-5)
      assert jnp.allclose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
      assert jnp.allclose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
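A single-device variant of the same comparison can serve as a quick smoke check; the sketch below is illustrative only (it reuses the f_train and g_train helpers above and assumes one CUDA device with a recent cuDNN):

def single_device_smoke_test():
  # Small fp16 problem on one device; dropout disabled so the cuDNN and reference paths agree.
  k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
  shape = (4, 1024, 8, 64)  # [batch, seq_len, num_heads, head_dim]
  query = jax.random.normal(k1, shape, dtype=jnp.float16)
  key = jax.random.normal(k2, shape, dtype=jnp.float16)
  value = jax.random.normal(k3, shape, dtype=jnp.float16)
  grad = jax.random.normal(k4, shape, dtype=jnp.float16)
  out, _ = jax.jit(partial(f_train, scale=0.5, dropout_rate=0.0))(query, key, value, grad)
  out_ref, _ = jax.jit(partial(g_train, scale=0.5, dropout_rate=0.0))(query, key, value, grad)
  # loose tolerances, since the cuDNN kernel accumulates differently from the fp16 reference
  np.testing.assert_allclose(np.asarray(out), np.asarray(out_ref), rtol=2e-2, atol=2e-2)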
@@ -1,122 +0,0 @@
import os
import argparse
from functools import partial
from typing import Any, Optional
from flax import linen as nn
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.fused_attention_stableHLO import dot_product_attention
from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.pjit import pjit
from jax import core, vmap
from jax import make_jaxpr

Array = jnp.ndarray
DType = jnp.dtype
PRNGKey = jnp.ndarray

def f(input_q: Array,
      input_k: Array,
      input_v: Array,
      bias: Optional[Array] = None,
      mask: Optional[Array] = None,
      causal_mask: bool = False,
      scale: float = 0.5,
      dropout_rate: float = 0.1) -> Array:

  output = dot_product_attention(
      input_q,
      input_k,
      input_v,
      scale=scale,
      bias=bias,
      mask=mask,
      is_cauasl_mask=causal_mask,
      dropout_rate=dropout_rate)

  return output

def train_step(input_q, input_k, input_v, bias, mask, grad, scale, causal_mask, dropout_rate):
  out, f_vjp = jax.vjp(
      vmap(partial(f, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate), in_axes=(0, 0, 0, 0, 0)),
      input_q, input_k, input_v, bias, mask
  )
  input_q_grad, input_k_grad, input_v_grad, bias_grad, _ = f_vjp(grad)
  return out, (input_q_grad, input_k_grad, input_v_grad, bias_grad, _)

def main():
  parser = argparse.ArgumentParser(description='T5X MHA Unit Test')
  parser.add_argument("--batch_size", dest="batch_size", type=int, default=16)
  parser.add_argument("--q_seq_len", dest="q_seq_len", type=int, default=1024)
  parser.add_argument("--kv_seq_len", dest="kv_seq_len", type=int, default=1024)
  parser.add_argument("--num_attn_heads", dest="num_attn_heads", type=int, default=16)
  parser.add_argument("--head_dim", dest="head_dim", type=int, default=64)
  parser.add_argument("--scale", dest="scale", type=float, default=1.0)
  parser.add_argument("--dropout_rate", dest="dropout_rate", type=float, default=0.)
  parser.add_argument("--bias", action="store_true")
  parser.add_argument("--mask", action="store_true")
  parser.add_argument("--causal_mask", action="store_true")
  parser.add_argument("--fwd_only", dest='fwd_only', action="store_true")
  parser.add_argument("--xla_dump_to", dest="xla_dump_to", type=str, default=None)
  args = parser.parse_args()

  if args.xla_dump_to:
    xla_flags = f'--xla_dump_hlo_pass_re=.* --xla_dump_hlo_as_text --xla_dump_to={args.xla_dump_to}'
    os.environ['XLA_FLAGS'] = xla_flags

  dtype = jnp.bfloat16
  key = jax.random.PRNGKey(0)
  key1, key2 = jax.random.split(key, 2)
  input_q = jax.random.uniform(key2, (4, args.batch_size, args.q_seq_len, args.num_attn_heads, args.head_dim), dtype=dtype)
  input_k = jax.random.uniform(key2, (4, args.batch_size, args.kv_seq_len, args.num_attn_heads, args.head_dim), dtype=dtype)
  input_v = jax.random.uniform(key2, (4, args.batch_size, args.kv_seq_len, args.num_attn_heads, args.head_dim), dtype=dtype)
  bias = jax.random.uniform(key2, (4, args.batch_size, args.num_attn_heads, args.q_seq_len, args.kv_seq_len), dtype=dtype) if args.bias else None
  mask = jax.random.uniform(key2, (4, args.batch_size, args.num_attn_heads, args.q_seq_len, args.kv_seq_len), dtype=dtype) if args.mask else None
  grad = jax.random.uniform(key2, (4, args.batch_size, args.q_seq_len, args.num_attn_heads, args.head_dim), dtype=dtype)

  if args.fwd_only:
    jitted_fwd_step = jax.jit(vmap(partial(f, causal_mask=args.causal_mask, scale=args.scale, dropout_rate=0), in_axes=(0, 0, 0, 0, 0)))
    out = jitted_fwd_step(input_q, input_k, input_v, bias, mask)
    print(out[0,0,0,0,:20])
    return out
  else:
    jitted_train_step = jax.jit(partial(train_step, causal_mask=args.causal_mask, scale=args.scale, dropout_rate=args.dropout_rate))
    # out, input_grad = jitted_train_step(input_q, input_k, input_v, bias, mask, grad)
    # print(input_grad[0][0,0,0,:20])
    # return out, input_grad
    devices = np.array(jax.local_devices())
    devices = devices.reshape((2, 4, 1, 1, 1))
    with Mesh(devices, ('p', 'b', 's', 'n', 'd')) as mesh:
      input_q = jax.device_put(input_q, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      input_k = jax.device_put(input_k, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      input_v = jax.device_put(input_v, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      bias = jax.device_put(bias, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      mask = jax.device_put(mask, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))
      grad = jax.device_put(grad, NamedSharding(mesh, PartitionSpec('p', 'b', None, None, None)))

      pjitter = pjit(jitted_train_step,
                     in_shardings=(
                         PartitionSpec('p', 'b', None, None, None),
                         PartitionSpec('p', 'b', None, None, None),
                         PartitionSpec('p', 'b', None, None, None),
                         PartitionSpec('p', 'b', None, None, None),
                         PartitionSpec('p', 'b', None, None, None),
                         PartitionSpec('p', 'b', None, None, None)),
                     out_shardings=(None,
                                    (PartitionSpec('p', 'b', None, None, None),
                                     PartitionSpec('p', 'b', None, None, None),
                                     PartitionSpec('p', 'b', None, None, None),
                                     None,
                                     None))
                     )

      out, grads = pjitter(input_q, input_k, input_v, bias, mask, grad)
      # print(make_jaxpr(pjitter)(input_q, input_k, input_v, bias, mask, grad))
      print(grads[0][0,0,0,0,:20])

if __name__ == "__main__":
  main()