# rocm_jax/ci/envs/default.env
# Nitin Srinivasan 031c0acf50 Add new CI scripts for running Pytests
# This commit adds the new CI scripts for running Pytests. It makes use of the
# pytest envs inside the "ci/envs/run_tests" folder to control the build
# behavior. For e.g: for running the GPU tests with Pytest, we will need to run
# `./ci/run_pytest.sh ./ci/envs/run_tests/pytest_gpu.env`. Note that Pytests
# need JAX wheels to be installed on the system to work. The
# `install_wheels_locally.sh` script installs these wheels in CI builds.
#
# PiperOrigin-RevId: 701331411
# 2024-11-29 12:08:17 -08:00
#
# 70 lines
# 3.3 KiB
# Bash

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This file contains all the default values for the "JAXCI_" environment
# variables used in the CI scripts. These variables control the behavior of
# the CI scripts, such as the Python version used, the paths to the JAX/XLA
# repos, whether to clone the XLA repo, etc.
# The path to the JAX git repository, taken from the current working directory
# (so the including script is expected to run from the repository root).
# Unlike most variables in this file, any pre-existing value is overwritten.
# Assignment and export are split so a failing `pwd` is not masked by the
# exit status of `export` (ShellCheck SC2155).
JAXCI_JAX_GIT_DIR=$(pwd)
export JAXCI_JAX_GIT_DIR
# Controls the version of Hermetic Python to use. If not already set (or set
# to an empty string), it defaults to the MAJOR.MINOR version of the system
# `python3`, e.g. "Python 3.11.4" -> "3.11". The probe only runs when no value
# was provided. Assignment is kept separate from `export` so a failure of the
# probe is not masked by export's exit status (ShellCheck SC2155).
if [ -z "${JAXCI_HERMETIC_PYTHON_VERSION:-}" ]; then
  # A single awk splits the second word ("3.11.4") on "." and keeps the
  # first two components.
  JAXCI_HERMETIC_PYTHON_VERSION=$(python3 -V \
    | awk '{ split($2, v, "."); print v[1] "." v[2] }')
fi
export JAXCI_HERMETIC_PYTHON_VERSION
# JAXCI_XLA_GIT_DIR: root of a local XLA git checkout to build against instead
# of the XLA version pinned in the WORKSPACE. Filled in automatically when
# JAXCI_CLONE_MAIN_XLA=1. Empty by default.
: "${JAXCI_XLA_GIT_DIR:=}"
# JAXCI_CLONE_MAIN_XLA: when 1, the builds clone the XLA repository at HEAD
# and store the clone's path in JAXCI_XLA_GIT_DIR. Defaults to 0.
: "${JAXCI_CLONE_MAIN_XLA:=0}"
# JAXCI_XLA_COMMIT: optional override of the XLA commit to use. Empty by
# default.
: "${JAXCI_XLA_COMMIT:=}"
export JAXCI_XLA_GIT_DIR JAXCI_CLONE_MAIN_XLA JAXCI_XLA_COMMIT
# Controls the location where the artifacts are written to: a "dist"
# directory under the current working directory. Any pre-existing value is
# overwritten. Assignment and export are split so a failing `pwd` is not
# masked by the exit status of `export` (ShellCheck SC2155).
JAXCI_OUTPUT_DIR="$(pwd)/dist"
export JAXCI_OUTPUT_DIR
# JAXCI_BUILD_ARTIFACT_WITH_RBE=1 builds the artifacts with RBE. RBE requires
# gcloud authentication and is supported only on certain platforms, so it is
# off (0) by default and turned on only for CI builds that support it.
: "${JAXCI_BUILD_ARTIFACT_WITH_RBE:=0}"
export JAXCI_BUILD_ARTIFACT_WITH_RBE
# #############################################################################
# Test script specific environment variables.
# #############################################################################
# Upper bound on how many tests may share one GPU when single-accelerator
# tests run in parallel under Bazel. Each test needs roughly 2GB of GPU
# memory, so the default of 12 fits the L4 machines used in CI (24GB of GPU
# memory). Override it when running on a different GPU type.
: "${JAXCI_MAX_TESTS_PER_GPU:=12}"
export JAXCI_MAX_TESTS_PER_GPU
# Value the test scripts use for `JAX_ENABLE_X64`. Defaults to 0; CI builds
# override it in the GitHub Actions workflow files.
: "${JAXCI_ENABLE_X64:=0}"
export JAXCI_ENABLE_X64
# Pytest specific environment variables below. Used in run_pytest_*.sh scripts.
# Number of TPU cores for the TPU machine type; the values are defined in the
# TPU GitHub Actions workflow. Empty by default.
: "${JAXCI_TPU_CORES:=}"
export JAXCI_TPU_CORES
# Python interpreter used to install the JAX wheels on the system. Defaults to
# "python" followed by JAXCI_HERMETIC_PYTHON_VERSION (e.g. "python3.11"), so
# it matches the hermetic Python that Bazel uses to build the wheels.
: "${JAXCI_PYTHON:=python${JAXCI_HERMETIC_PYTHON_VERSION}}"
export JAXCI_PYTHON