Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
Adds Pallas flash attention TPU kernel. Implementation based on https://arxiv.org/pdf/2205.14135.pdf.
PiperOrigin-RevId: 560346791
This commit is contained in:
parent
08ca945271
commit
841baabd3f
14
jax/BUILD
14
jax/BUILD
@@ -496,6 +496,7 @@ pytype_strict_library(
|
||||
exclude = [
|
||||
"experimental/pallas/gpu.py",
|
||||
"experimental/pallas/tpu.py",
|
||||
"experimental/pallas/ops/tpu/*.py",
|
||||
],
|
||||
),
|
||||
visibility = [
|
||||
@@ -521,6 +522,19 @@ pytype_strict_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "pallas_tpu_ops",
|
||||
srcs = glob(["experimental/pallas/ops/tpu/*.py"]),
|
||||
visibility = [
|
||||
":pallas_tpu_users",
|
||||
],
|
||||
deps = [
|
||||
":jax",
|
||||
":pallas",
|
||||
":pallas_tpu",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "pallas_gpu",
|
||||
srcs = ["experimental/pallas/gpu.py"],
|
||||
|
13
jax/experimental/pallas/ops/tpu/__init__.py
Normal file
13
jax/experimental/pallas/ops/tpu/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
1147
jax/experimental/pallas/ops/tpu/flash_attention.py
Normal file
1147
jax/experimental/pallas/ops/tpu/flash_attention.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user