mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Merge pull request #20128 from shuhand0:dev/shuhan/ci2
PiperOrigin-RevId: 615231412
This commit is contained in:
commit
60bf38bde9
42
.github/workflows/metal_plugin_ci.yml
vendored
Normal file
42
.github/workflows/metal_plugin_ci.yml
vendored
Normal file
@ -0,0 +1,42 @@
|
||||
# JAX-Metal plugin CI
|
||||
|
||||
name: Jax-Metal CI
|
||||
on:
|
||||
workflow_dispatch: # allows triggering the workflow run manually
|
||||
|
||||
jobs:
|
||||
jax-metal-plugin-test:
|
||||
|
||||
strategy:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
matrix:
|
||||
jaxlib-version: ["plugin_latest"]
|
||||
name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})"
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
steps:
|
||||
- name: Get repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
path: jax
|
||||
- name: Setup build and test enviroment
|
||||
run: |
|
||||
rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv
|
||||
python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv
|
||||
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
|
||||
pip install -U pip numpy wheel
|
||||
pip install jax-metal absl-py pytest
|
||||
if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
|
||||
pip install --pre jaxlib \
|
||||
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
fi;
|
||||
cd jax
|
||||
pip install .
|
||||
- name: Run test
|
||||
run: |
|
||||
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
|
||||
export ENABLE_PJRT_COMPATIBILITY=1
|
||||
cd jax
|
||||
pytest tests/lax_metal_test.py
|
||||
|
||||
|
@ -349,6 +349,8 @@ def supported_dtypes():
|
||||
elif device_under_test() == "iree":
|
||||
types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
|
||||
np.uint32, np.float32}
|
||||
elif device_under_test() == "METAL":
|
||||
types = {np.int32, np.uint32, np.float32}
|
||||
else:
|
||||
types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
@ -423,6 +425,8 @@ def _get_device_tags():
|
||||
device_tags = {device_under_test(), "rocm"}
|
||||
elif is_device_cuda():
|
||||
device_tags = {device_under_test(), "cuda"}
|
||||
elif device_under_test() == "METAL":
|
||||
device_tags = {device_under_test(), "gpu"}
|
||||
else:
|
||||
device_tags = {device_under_test()}
|
||||
return device_tags
|
||||
|
15
tests/BUILD
15
tests/BUILD
@ -542,6 +542,21 @@ jax_test(
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "lax_metal_test",
|
||||
srcs = ["lax_metal_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
"tpu",
|
||||
],
|
||||
tags = ["notap"],
|
||||
deps = [
|
||||
"//jax:internal_test_util",
|
||||
"//jax:lax_reference",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "lax_autodiff_test",
|
||||
srcs = ["lax_autodiff_test.py"],
|
||||
|
5805
tests/lax_metal_test.py
Normal file
5805
tests/lax_metal_test.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user