Merge pull request #24197 from yhtang:add-k8s-ci

PiperOrigin-RevId: 743302226
This commit is contained in:
jax authors 2025-04-02 15:33:18 -07:00
commit c8273d7795
5 changed files with 184 additions and 9 deletions

101
.github/workflows/k8s.yaml vendored Normal file
View File

@ -0,0 +1,101 @@
name: Distributed run using K8s Jobset
on:
push:
branches:
- main
pull_request:
branches:
- main
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
defaults:
run:
shell: bash -ex -o pipefail {0}
jobs:
distributed-initialize:
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4
with:
path: jax
- name: Start Minikube cluster
uses: medyagh/setup-minikube@d8c0eb871f6f455542491d86a574477bd3894533 # ratchet:medyagh/setup-minikube@v0.0.18
- name: Install K8s Jobset
run: |
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml
- name: Build image
run: |
cat > Dockerfile <<EOF
FROM ubuntu:22.04
ADD jax /opt/jax
RUN apt-get update && apt-get install -y python-is-python3 python3-pip
RUN pip install -e /opt/jax[k8s]
EOF
minikube image build -t local/jax:latest .
- name: Create service account for K8s job introspection
run: |
kubectl apply -f jax/examples/k8s/svc-acct.yaml
- name: Prepare test job
run: |
export VERSION=v4.44.3
export BINARY=yq_linux_amd64
wget https://github.com/mikefarah/yq/releases/download/${VERSION}/${BINARY} -O /usr/bin/yq && chmod +x /usr/bin/yq
cat jax/examples/k8s/example.yaml |\
yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].image = "local/jax:latest"' |\
yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].imagePullPolicy = "Never"' |\
tee example.yaml
- name: Submit test job
run: |
kubectl apply -f example.yaml
- name: Check job status
shell: bash -e -o pipefail {0}
run: |
while true; do
status=$(kubectl get jobset example -o yaml | yq .status.conditions[0].type)
timestamp=$(date +"%Y-%m-%d %H:%M:%S")
echo "[$timestamp] Checking job status..."
if [ "$status" == "Completed" ]; then
echo "[$timestamp] Job has completed successfully!"
exit 0
elif [ "$status" == "Failed" ]; then
echo "[$timestamp] Job has failed!"
exit 1
else
echo "[$timestamp] Job is still running. Current pod status:"
kubectl get pods --no-headers
echo "[$timestamp] Waiting for 3 seconds before checking again..."
sleep 3
fi
done
- name: Examine individual pod outputs
if: "!cancelled()"
run: |
set +x
kubectl get pods --no-headers | awk '{print $1}' | while read -s pod; do
echo "========================================"
echo "Pod $pod output:"
echo "----------------------------------------"
kubectl logs $pod
echo "========================================"
done

View File

@ -15,6 +15,7 @@ repos:
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
exclude: examples/k8s/svc-acct.yaml
- id: end-of-file-fixer
# only include python files
files: \.py$

40
examples/k8s/example.yaml Normal file
View File

@ -0,0 +1,40 @@
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: example
spec:
replicatedJobs:
- name: workers
template:
spec:
parallelism: 2
completions: 2
backoffLimit: 0
template:
spec:
serviceAccountName: training-job-sa
restartPolicy: Never
imagePullSecrets:
- name: null
containers:
- name: main
image: PLACEHOLDER
imagePullPolicy: IfNotPresent
resources:
requests:
cpu: 900m
nvidia.com/gpu: null
limits:
cpu: 1
nvidia.com/gpu: null
command:
- python
args:
- -c
- |
import jax
jax.distributed.initialize()
print(jax.devices())
print(jax.local_devices())
assert jax.process_count() > 1
assert len(jax.devices()) > len(jax.local_devices())

View File

@ -0,0 +1,31 @@
apiVersion: v1
kind: ServiceAccount
metadata:
name: training-job-sa
namespace: default
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
name: pod-reader
rules:
- apiGroups: [""]
resources: ["pods"]
verbs: ["get", "list", "watch"]
- apiGroups: ["batch"]
resources: ["jobs"]
verbs: ["get", "list", "watch"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: pod-reader-binding
namespace: default
subjects:
- kind: ServiceAccount
name: training-job-sa
namespace: default
roleRef:
kind: Role
name: pod-reader
apiGroup: rbac.authorization.k8s.io

View File

@ -35,15 +35,17 @@ class K8sCluster(clusters.ClusterEnv):
try:
import kubernetes as k8s # pytype: disable=import-error
except ImportError as e:
warnings.warn(textwrap.fill(
"Kubernetes environment detected, but the `kubernetes` package is "
"not installed to enable automatic bootstrapping in this "
"environment. To enable automatic boostrapping, please install "
"jax with the [k8s] extra. For example:"
" pip install jax[k8s]"
" OR"
" pip install jax[k8s,<MORE-EXTRAS...>]"
))
warnings.warn(
'\n'.join([
textwrap.fill(
"Kubernetes environment detected, but the `kubernetes` package "
"is not installed to enable automatic bootstrapping in this "
"environment. To enable automatic boostrapping, please install "
"jax with the [k8s] extra. For example:"),
" pip install jax[k8s]",
" pip install jax[k8s,<MORE-EXTRAS...>]",
])
)
return False
k8s.config.load_incluster_config()