mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #24197 from yhtang:add-k8s-ci
PiperOrigin-RevId: 743302226
This commit is contained in:
commit
c8273d7795
101
.github/workflows/k8s.yaml
vendored
Normal file
101
.github/workflows/k8s.yaml
vendored
Normal file
@ -0,0 +1,101 @@
|
||||
name: Distributed run using K8s Jobset
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -ex -o pipefail {0}
|
||||
|
||||
jobs:
|
||||
|
||||
distributed-initialize:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4
|
||||
with:
|
||||
path: jax
|
||||
|
||||
- name: Start Minikube cluster
|
||||
uses: medyagh/setup-minikube@d8c0eb871f6f455542491d86a574477bd3894533 # ratchet:medyagh/setup-minikube@v0.0.18
|
||||
|
||||
- name: Install K8s Jobset
|
||||
run: |
|
||||
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml
|
||||
|
||||
- name: Build image
|
||||
run: |
|
||||
cat > Dockerfile <<EOF
|
||||
FROM ubuntu:22.04
|
||||
ADD jax /opt/jax
|
||||
RUN apt-get update && apt-get install -y python-is-python3 python3-pip
|
||||
RUN pip install -e /opt/jax[k8s]
|
||||
EOF
|
||||
|
||||
minikube image build -t local/jax:latest .
|
||||
|
||||
- name: Create service account for K8s job introspection
|
||||
run: |
|
||||
kubectl apply -f jax/examples/k8s/svc-acct.yaml
|
||||
|
||||
- name: Prepare test job
|
||||
run: |
|
||||
export VERSION=v4.44.3
|
||||
export BINARY=yq_linux_amd64
|
||||
wget https://github.com/mikefarah/yq/releases/download/${VERSION}/${BINARY} -O /usr/bin/yq && chmod +x /usr/bin/yq
|
||||
|
||||
cat jax/examples/k8s/example.yaml |\
|
||||
yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].image = "local/jax:latest"' |\
|
||||
yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].imagePullPolicy = "Never"' |\
|
||||
tee example.yaml
|
||||
|
||||
- name: Submit test job
|
||||
run: |
|
||||
kubectl apply -f example.yaml
|
||||
|
||||
- name: Check job status
|
||||
shell: bash -e -o pipefail {0}
|
||||
run: |
|
||||
while true; do
|
||||
status=$(kubectl get jobset example -o yaml | yq .status.conditions[0].type)
|
||||
timestamp=$(date +"%Y-%m-%d %H:%M:%S")
|
||||
echo "[$timestamp] Checking job status..."
|
||||
|
||||
if [ "$status" == "Completed" ]; then
|
||||
echo "[$timestamp] Job has completed successfully!"
|
||||
exit 0
|
||||
elif [ "$status" == "Failed" ]; then
|
||||
echo "[$timestamp] Job has failed!"
|
||||
exit 1
|
||||
else
|
||||
echo "[$timestamp] Job is still running. Current pod status:"
|
||||
kubectl get pods --no-headers
|
||||
echo "[$timestamp] Waiting for 3 seconds before checking again..."
|
||||
sleep 3
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Examine individual pod outputs
|
||||
if: "!cancelled()"
|
||||
run: |
|
||||
set +x
|
||||
kubectl get pods --no-headers | awk '{print $1}' | while read -s pod; do
|
||||
echo "========================================"
|
||||
echo "Pod $pod output:"
|
||||
echo "----------------------------------------"
|
||||
kubectl logs $pod
|
||||
echo "========================================"
|
||||
done
|
@ -15,6 +15,7 @@ repos:
|
||||
- id: check-merge-conflict
|
||||
- id: check-toml
|
||||
- id: check-yaml
|
||||
exclude: examples/k8s/svc-acct.yaml
|
||||
- id: end-of-file-fixer
|
||||
# only include python files
|
||||
files: \.py$
|
||||
|
40
examples/k8s/example.yaml
Normal file
40
examples/k8s/example.yaml
Normal file
@ -0,0 +1,40 @@
|
||||
apiVersion: jobset.x-k8s.io/v1alpha2
|
||||
kind: JobSet
|
||||
metadata:
|
||||
name: example
|
||||
spec:
|
||||
replicatedJobs:
|
||||
- name: workers
|
||||
template:
|
||||
spec:
|
||||
parallelism: 2
|
||||
completions: 2
|
||||
backoffLimit: 0
|
||||
template:
|
||||
spec:
|
||||
serviceAccountName: training-job-sa
|
||||
restartPolicy: Never
|
||||
imagePullSecrets:
|
||||
- name: null
|
||||
containers:
|
||||
- name: main
|
||||
image: PLACEHOLDER
|
||||
imagePullPolicy: IfNotPresent
|
||||
resources:
|
||||
requests:
|
||||
cpu: 900m
|
||||
nvidia.com/gpu: null
|
||||
limits:
|
||||
cpu: 1
|
||||
nvidia.com/gpu: null
|
||||
command:
|
||||
- python
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import jax
|
||||
jax.distributed.initialize()
|
||||
print(jax.devices())
|
||||
print(jax.local_devices())
|
||||
assert jax.process_count() > 1
|
||||
assert len(jax.devices()) > len(jax.local_devices())
|
31
examples/k8s/svc-acct.yaml
Normal file
31
examples/k8s/svc-acct.yaml
Normal file
@ -0,0 +1,31 @@
|
||||
apiVersion: v1
|
||||
kind: ServiceAccount
|
||||
metadata:
|
||||
name: training-job-sa
|
||||
namespace: default
|
||||
---
|
||||
apiVersion: rbac.authorization.k8s.io/v1
|
||||
kind: Role
|
||||
metadata:
|
||||
name: pod-reader
|
||||
rules:
|
||||
- apiGroups: [""]
|
||||
resources: ["pods"]
|
||||
verbs: ["get", "list", "watch"]
|
||||
- apiGroups: ["batch"]
|
||||
resources: ["jobs"]
|
||||
verbs: ["get", "list", "watch"]
|
||||
---
|
||||
apiVersion: rbac.authorization.k8s.io/v1
|
||||
kind: RoleBinding
|
||||
metadata:
|
||||
name: pod-reader-binding
|
||||
namespace: default
|
||||
subjects:
|
||||
- kind: ServiceAccount
|
||||
name: training-job-sa
|
||||
namespace: default
|
||||
roleRef:
|
||||
kind: Role
|
||||
name: pod-reader
|
||||
apiGroup: rbac.authorization.k8s.io
|
@ -35,15 +35,17 @@ class K8sCluster(clusters.ClusterEnv):
|
||||
try:
|
||||
import kubernetes as k8s # pytype: disable=import-error
|
||||
except ImportError as e:
|
||||
warnings.warn(textwrap.fill(
|
||||
"Kubernetes environment detected, but the `kubernetes` package is "
|
||||
"not installed to enable automatic bootstrapping in this "
|
||||
"environment. To enable automatic boostrapping, please install "
|
||||
"jax with the [k8s] extra. For example:"
|
||||
" pip install jax[k8s]"
|
||||
" OR"
|
||||
" pip install jax[k8s,<MORE-EXTRAS...>]"
|
||||
))
|
||||
warnings.warn(
|
||||
'\n'.join([
|
||||
textwrap.fill(
|
||||
"Kubernetes environment detected, but the `kubernetes` package "
|
||||
"is not installed to enable automatic bootstrapping in this "
|
||||
"environment. To enable automatic boostrapping, please install "
|
||||
"jax with the [k8s] extra. For example:"),
|
||||
" pip install jax[k8s]",
|
||||
" pip install jax[k8s,<MORE-EXTRAS...>]",
|
||||
])
|
||||
)
|
||||
return False
|
||||
|
||||
k8s.config.load_incluster_config()
|
||||
|
Loading…
x
Reference in New Issue
Block a user