Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Fix import for Windows platforms
PiperOrigin-RevId: 720348679
parent faaaf82974 · commit 4fe937683e
@@ -33,16 +33,16 @@ if sys.platform != "win32":
   from jax.experimental.pallas.ops.gpu import layer_norm
   from jax.experimental.pallas.ops.gpu import rms_norm
   from jax.experimental.pallas.ops.gpu import softmax
+  BlockSizes = attention.BlockSizes
 else:
   attention = None
   layer_norm = None
   rms_norm = None
   softmax = None
+  BlockSizes = None
 import jax.numpy as jnp
 import numpy as np

-BlockSizes = attention.BlockSizes
-
 # TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs.
 # pylint: disable=no-value-for-parameter
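For context, a minimal sketch of the pattern this hunk applies (illustrative only, not the test file itself): on Windows the Triton-backed Pallas GPU ops are never imported, so attention is bound to None and a module-level reference to attention.BlockSizes would fail at import time. Binding the alias inside the same platform guard, with a None fallback, keeps the module importable everywhere.

# Sketch of the guarded-import pattern (illustrative; the real file also imports
# layer_norm, rms_norm, and softmax under the same guard).
import sys

if sys.platform != "win32":
    # Triton-backed Pallas GPU ops are unavailable on Windows, so import them
    # (and alias anything derived from them) only when the platform supports them.
    from jax.experimental.pallas.ops.gpu import attention
    BlockSizes = attention.BlockSizes
else:
    # Keep the module importable on Windows: every optional name becomes None.
    attention = None
    BlockSizes = None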
@@ -155,9 +155,9 @@ class FusedAttentionTest(PallasBaseTest):
       num_heads=(1, 2, 8),
       head_dim=(32, 64, 128),
       block_sizes=(
-          BlockSizes.get_default(),
-          BlockSizes(block_q=64,block_k=64),
-          BlockSizes(block_q=64,block_k=128),
+          (("block_q", 128), ("block_k", 128)),
+          (("block_q", 64), ("block_k", 64)),
+          (("block_q", 64), ("block_k", 128)),
       ),
       causal=(True, False),
       use_fwd=(True, False),
@@ -199,7 +199,7 @@ class FusedAttentionTest(PallasBaseTest):
       v, _ = jax.vjp(
           functools.partial(
               attention.mha,
-              block_sizes=block_sizes,
+              block_sizes=BlockSizes(**dict(block_sizes)),
               causal=causal,
               segment_ids=segment_ids,
               interpret=self.INTERPRET,
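The companion change, here and in the hunks that follow: the test parametrization now lists plain tuples of (name, value) pairs instead of BlockSizes instances, so nothing touches BlockSizes while the module is being imported, where it may be None on Windows. A BlockSizes object is only rebuilt inside the test body. A small sketch of that reconstruction, assuming a non-Windows platform where the guarded import above succeeded:

# Sketch: rebuilding BlockSizes from a tuple-of-pairs test parameter.
# Assumes the guarded import resolved to the real attention module.
from jax.experimental.pallas.ops.gpu import attention

block_sizes = (("block_q", 64), ("block_k", 128))  # one of the parametrized cases

# dict(...) turns the pairs back into keyword arguments; the BlockSizes object is
# only constructed here, at test time, never at module import time.
bs = attention.BlockSizes(**dict(block_sizes))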
@@ -213,7 +213,7 @@ class FusedAttentionTest(PallasBaseTest):
     else:
       impl = functools.partial(
           attention.mha,
-          block_sizes=block_sizes,
+          block_sizes=BlockSizes(**dict(block_sizes)),
          causal=causal,
          segment_ids=segment_ids,
          interpret=self.INTERPRET,
@@ -228,23 +228,30 @@ class FusedAttentionTest(PallasBaseTest):
       num_heads=(1, 2),
       head_dim=(32, 64, 128,),
       block_sizes=(
-          BlockSizes.get_default(),
-          BlockSizes(
-              block_q=128,
-              block_k=128,
-              block_q_dkv=64,
-              block_kv_dkv=64,
-              block_q_dq=64,
-              block_kv_dq=64,
-          ),
-          BlockSizes(
-              block_q=128,
-              block_k=128,
-              block_q_dkv=64,
-              block_kv_dkv=128,
-              block_q_dq=128,
-              block_kv_dq=64,
-          ),
+          (
+              ("block_q", 128),
+              ("block_k", 128),
+              ("block_q_dkv", 128),
+              ("block_kv_dkv", 128),
+              ("block_q_dq", 128),
+              ("block_kv_dq", 128),
+          ),
+          (
+              ("block_q", 64),
+              ("block_k", 64),
+              ("block_q_dkv", 64),
+              ("block_kv_dkv", 64),
+              ("block_q_dq", 64),
+              ("block_kv_dq", 64),
+          ),
+          (
+              ("block_q", 64),
+              ("block_k", 128),
+              ("block_q_dkv", 64),
+              ("block_kv_dkv", 128),
+              ("block_q_dq", 128),
+              ("block_kv_dq", 64),
+          ),
       ),
       causal=(True, False),
       use_segment_ids=(True, False),
@@ -280,7 +287,7 @@ class FusedAttentionTest(PallasBaseTest):
     def f(q, k, v):
       return attention.mha(
           q, k, v,
-          block_sizes=block_sizes,
+          block_sizes=BlockSizes(**dict(block_sizes)),
           causal=causal,
           segment_ids=segment_ids,
           interpret=self.INTERPRET).sum()