29 Commits

Author SHA1 Message Date
jax authors
582c56a707 Change determination of cloud TPU to check for TPU chips.
This is useful in the case of ahead of time compilation, when libtpu is present but there may not be any TPU chips, so we shouldn't attempt to initialize a TPU backend.

PiperOrigin-RevId: 630055511
2024-05-02 07:22:56 -07:00
David Hall
6a891f2cd9 set tensorstore settings in cloud_tpu that make serialization more robust 2024-04-26 14:19:06 -07:00
Peter Hawkins
aad02dba7e Increase minimum jaxlib version to 0.4.20.
jaxlib 0.4.20 has xla_extension_version 210 and mlir_api_version 54.

PiperOrigin-RevId: 609094229
2024-02-21 12:58:57 -08:00
Shiva Shahrokhi
65f3e4fffd making sure enhanced barrier only turns on when there is a supported TPU available. 2024-01-12 23:47:37 +00:00
Jieying Luo
b81a3e1fd7 Remove calling configure_library_path during jax import and get libtpu path from libtpu_module.get_library_path().
PiperOrigin-RevId: 572306461
2023-10-10 10:59:37 -07:00
Skye Wanderman-Milne
9d1cbc7d21 Default to PJRT TPU runtime instead of StreamExecutor on older jaxlibs.
I messed up the forwards compat in 3e50fea29e. The next jaxlib release
won't need the env var at all, but jaxlib 0.4.14 and older still do.
2023-08-23 00:06:16 +00:00
Skye Wanderman-Milne
3e50fea29e Remove option to use StreamExecutor Cloud TPU client in JAX
It's been over three months since the new PJRT C API client was
enabled by default
(https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-8-march-29-2023).

PiperOrigin-RevId: 554935166
2023-08-08 14:05:27 -07:00
Skye Wanderman-Milne
71cbfe49f0 Issue warning if JAX_USE_PJRT_C_API_ON_TPU=false
```
$ JAX_USE_PJRT_C_API_ON_TPU=0 python3 -c "import jax"
/home/skyewm/jax/jax/_src/cloud_tpu_init.py:77: UserWarning: JAX_USE_PJRT_C_API_ON_TPU=0 will no longer be supported in an upcoming future release. Please file an issue at https://github.com/google/jax/issues if you need this setting.
  warnings.warn(
```
2023-06-20 14:27:08 -07:00
Skye Wanderman-Milne
473d1c3685 Turn on PJRT C API by default.
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)

To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
2023-03-28 15:28:13 -07:00
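
Editor's sketch of the "set the env var if it isn't already set" pattern described in the commit above. The helper name `default_env` and the value `"1"` are assumptions for illustration only; the commit does not show the actual code.

```python
import os


def default_env(name, value, environ=os.environ):
    """Set environ[name] to value only if the user has not set it already.

    Hypothetical helper, not taken from cloud_tpu_init.py.
    """
    # setdefault leaves an existing value untouched and returns whatever
    # value is in effect afterwards.
    return environ.setdefault(name, value)


if __name__ == "__main__":
    # Assumed value "1"; the commit only says the variable is set by default.
    print(default_env("JAX_USE_PJRT_C_API_ON_TPU", "1"))
```

Passing a plain dict as `environ` makes it easy to check that an explicit user setting such as `JAX_USE_PJRT_C_API_ON_TPU=0` is not overwritten.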
Peter Hawkins
7bfd89a89c Split _src modules cloud_tpu_init, lazy_loader, path, monitoring into their own pytype_library Bazel targets.
PiperOrigin-RevId: 515420193
2023-03-09 13:11:04 -08:00
jax authors
e6e513a6e9 Add environment variable check in addition to libtpu check for cloud tpu vm
PiperOrigin-RevId: 504588621
2023-01-25 09:51:37 -08:00
Skye Wanderman-Milne
004b8c1a09 Don't set TPU topology env vars in cloud_tpu_init.py
This used to be necessary. However, now these are automatically set in
libtpu. Beyond being redundant, the Python logic needs to be updated
to avoid getting KeyErrors on new topologies and TPU versions, so
better to remove it.

This also moves `get_metadata` to cloud_tpu_cluster.py since it's only
used in that file now.
2023-01-13 22:02:16 +00:00
Ran Ran
78a7e161bb Set JAX_PLATFORMS=tpu,cpu on TPUs 2022-10-05 04:39:04 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
840c96692e Internal change
PiperOrigin-RevId: 468509799
2022-08-18 11:39:07 -07:00
Peter Hawkins
1e241dcf16 Catch ModuleNotFoundError instead of ImportError.
We frequently use the pattern
try:
  import m
except ImportError:
  # do something else.

This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
2022-08-18 15:22:49 +00:00
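
Editor's sketch of the distinction the commit above draws. The helper `optional_import` and the module names are hypothetical; the demo writes a deliberately broken module to a temporary directory to show that its ImportError is no longer swallowed.

```python
import importlib
import os
import sys
import tempfile


def optional_import(name):
    """Return the imported module, or None if it is not installed."""
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError:
        # Only a missing module is tolerated; other import failures propagate.
        return None


def demo():
    """Return (result for a missing module, whether a buggy module raised)."""
    with tempfile.TemporaryDirectory() as tmp:
        # A module that exists but fails while importing: this raises a plain
        # ImportError ("cannot import name"), not ModuleNotFoundError.
        with open(os.path.join(tmp, "hypothetical_buggy_mod.py"), "w") as f:
            f.write("from os import name_that_does_not_exist\n")
        sys.path.insert(0, tmp)
        try:
            importlib.invalidate_caches()
            missing = optional_import("hypothetical_missing_mod_xyz")
            try:
                optional_import("hypothetical_buggy_mod")
                buggy_raised = False
            except ImportError:
                buggy_raised = True
        finally:
            sys.path.remove(tmp)
    return missing, buggy_raised
```

One caveat: a ModuleNotFoundError raised by a missing dependency inside an otherwise present module is still caught by this pattern.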
Shiva Shahrokhi
88f1b9fae7 adding os env to track JAX platform 2022-06-14 21:21:34 +00:00
Yash Katariya
548a6bf58b * Make all arguments to distributed.initialize default to None.
* On Cloud TPUs, figure out the coordinator address automatically.

PiperOrigin-RevId: 449261786
2022-05-17 10:53:54 -07:00
David Silverstone
4a631886c6 Fixed a typo in cloud_tpu_init.py 2022-05-11 21:19:28 +00:00
jax authors
96173913f0 Add retry logic to cloud_tpu_init.py when API response fails
PiperOrigin-RevId: 447509510
2022-05-09 10:46:10 -07:00
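
Editor's sketch of what retry logic around a failing API call can look like. The commit above does not show its implementation, so the function name, the retry count, and the delay here are all assumptions.

```python
import time


def call_with_retries(fetch, retry_count=6, retry_seconds=0.5, sleep=time.sleep):
    """Call fetch() until it succeeds or retry_count attempts have failed.

    Hypothetical helper; the parameter names and defaults are illustrative.
    """
    last_error = None
    for attempt in range(retry_count):
        try:
            return fetch()
        except Exception as e:  # a real caller would catch a narrower error type
            last_error = e
            # Do not sleep after the final failed attempt.
            if attempt < retry_count - 1:
                sleep(retry_seconds)
    raise RuntimeError(f"failed after {retry_count} attempts") from last_error
```

The `sleep` parameter is injectable so the behaviour can be exercised without real delays.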
Gain Hagenau
59d8b8d6b2 Remove flags set for all v4 TPUs. Topology flags will now be set in libTPU.
Remove deprecated fields `TPU_MESH_CONTROLLER_ADDRESS` and `TPU_MESH_CONTROLLER_PORT`.

PiperOrigin-RevId: 442663216
2022-04-18 16:39:34 -07:00
Skye Wanderman-Milne
b36cd2553d Add TPU v4 support to cloud_tpu_init
Tested manually on a v4-8 and v4-32.
2022-03-15 16:00:16 -07:00
Skye Wanderman-Milne
b2fd6a772b Changes to make jax[tpu] work better in a docker container.
1. In cloud_tpu_init.py, check whether we're on a Cloud TPU VM by
   looking for the libtpu Python package, instead of /lib/libtpu.so
   (which isn't necessarily present in a docker container). JAX now
   relies on the libtpu package instead of the system libtpu.so, so
   this makes more sense either way. This means we'll try/catch an
   ImportError in all non-TPU environments when importing jax, which
   hopefully isn't noticeably slow.

2. Add requests as a jax[tpu] dependency, since it's needed by
   cloud_tpu_init.py. This comes pre-installed on Cloud TPU VMs, but
   may not be installed in docker containers, virtualenvs, etc.

I manually tested by creating the following Dockerfile on a Cloud TPU VM:
```
FROM ubuntu:18.04
RUN apt update && apt install git python3-pip -y
RUN git clone https://github.com/skye/jax && cd jax && git checkout tpu_docker
WORKDIR jax
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
CMD ["python3", "-c", "import jax; print(jax.device_count())"]
```

And then running the following commands:
```
$ sudo docker build -t jax-test .
$ sudo docker run --privileged jax-test
8
```

Note the `--privileged` flag is necessary to let the container access
the TPU devices in /dev.
2021-07-12 17:42:46 -07:00
Skye Wanderman-Milne
ba972f0207 On Cloud TPU, use pip-installed libtpu instead of system default if applicable. 2021-06-22 23:51:57 +00:00
Skye Wanderman-Milne
6722c14589 Add v3-64 config to automatic Cloud TPU pod slice initialization. 2021-04-19 15:49:40 -07:00
Skye Wanderman-Milne
5cb5056ea7 Suppress gRPC log spam on Cloud TPU. 2021-03-11 22:52:54 +00:00
Skye Wanderman-Milne
c32d1e5aae Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.
This removes the need to manually set these env vars when running on a Cloud TPU pod slice.
2021-03-10 09:15:31 -08:00
Skye Wanderman-Milne
902038a718 Revert breaking change:
Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.

This removes the need to manually set these env vars when running on a Cloud TPU pod slice.

PiperOrigin-RevId: 361681134
2021-03-08 16:10:54 -08:00
Skye Wanderman-Milne
5a2859e1b6 Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.
This removes the need to manually set these env vars when running on a Cloud TPU pod slice.
2021-03-08 12:04:32 -08:00