Merge pull request #20128 from shuhand0:dev/shuhan/ci2

PiperOrigin-RevId: 615231412
This commit is contained in:
jax authors 2024-03-12 17:45:19 -07:00
commit 60bf38bde9
4 changed files with 5866 additions and 0 deletions

42
.github/workflows/metal_plugin_ci.yml vendored Normal file
View File

@ -0,0 +1,42 @@
# JAX-Metal plugin CI
name: Jax-Metal CI
on:
workflow_dispatch: # allows triggering the workflow run manually
jobs:
jax-metal-plugin-test:
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
jaxlib-version: ["plugin_latest"]
name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})"
runs-on: [self-hosted, macOS, ARM64]
steps:
- name: Get repo
uses: actions/checkout@v4
with:
path: jax
- name: Setup build and test enviroment
run: |
rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv
python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
pip install -U pip numpy wheel
pip install jax-metal absl-py pytest
if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
pip install --pre jaxlib \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
fi;
cd jax
pip install .
- name: Run test
run: |
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
export ENABLE_PJRT_COMPATIBILITY=1
cd jax
pytest tests/lax_metal_test.py

View File

@ -349,6 +349,8 @@ def supported_dtypes():
elif device_under_test() == "iree":
types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
np.uint32, np.float32}
elif device_under_test() == "METAL":
types = {np.int32, np.uint32, np.float32}
else:
types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
@ -423,6 +425,8 @@ def _get_device_tags():
device_tags = {device_under_test(), "rocm"}
elif is_device_cuda():
device_tags = {device_under_test(), "cuda"}
elif device_under_test() == "METAL":
device_tags = {device_under_test(), "gpu"}
else:
device_tags = {device_under_test()}
return device_tags

View File

@ -542,6 +542,21 @@ jax_test(
] + py_deps("numpy"),
)
jax_test(
name = "lax_metal_test",
srcs = ["lax_metal_test.py"],
disable_backends = [
"cpu",
"gpu",
"tpu",
],
tags = ["notap"],
deps = [
"//jax:internal_test_util",
"//jax:lax_reference",
] + py_deps("numpy"),
)
jax_test(
name = "lax_autodiff_test",
srcs = ["lax_autodiff_test.py"],

5805
tests/lax_metal_test.py Normal file

File diff suppressed because it is too large Load Diff