TPU v2 is about 30% slower than even an RTX 2070 under pytorch-lightning. You can try it if you work with TensorFlow; otherwise, forget the free TPU: it is much slower, and it has many bugs with poor documentation.
1 Preparation
Step 1 : Install Google Cloud CLI via PowerShell
(New-Object Net.WebClient).DownloadFile("https://dl.google.com/dl/cloudsdk/channels/rapid/GoogleCloudSDKInstaller.exe", "$env:Temp\GoogleCloudSDKInstaller.exe")
& $env:Temp\GoogleCloudSDKInstaller.exe
Optional : Connectivity is limited in China. On the campus network I got ERROR: gcloud crashed (SSLError): HTTPSConnectionPool(host='oauth2.googleapis.com', port=443): Max retries exceeded with url, so I use clash-for-windows as a proxy for gcloud.
gcloud config set proxy/type http
gcloud config set proxy/address 127.0.0.1
gcloud config set proxy/port 7890
Then log in to gcloud and confirm the redirect prompt in the browser. Login succeeded if the CLI outputs You are now logged in as [MyAccount]. Your current project is [MyProject]. You can change this setting by running.
gcloud auth login
Set the Google Cloud project and account:
gcloud config set project <your-project-id>
gcloud config set account <your-email-account>
Step 2 : Create TPU Virtual Machine
Check your dependencies: verify your torch version and the versions of torch-related packages.
pip3 list | grep torch
E.g. I use pytorch-forecasting==0.10.2 and pytorch-lightning==1.7.2 for my project. I looked up their torch requirements on pypi.org and picked the latest supported torch version. The free-trial TPU quota table is below; you can find it in the TRC response letter.
TPU Type | Use Type | Num | Zone |
---|---|---|---|
v2-8 | on-demand | 5 | us-central1-f |
v3-8 | on-demand | 5 | europe-west4-a |
v2-8 | preemptible | 100 | us-central1-f |
Then fill in the following command using the table and create a TPU VM.
gcloud compute tpus tpu-vm create <your-tpu-name> --zone=us-central1-f --accelerator-type=v2-8 --version=tpu-vm-pt-1.10
Use ssh to connect to the VM you have created.
gcloud compute tpus tpu-vm ssh <your-tpu-name> --zone us-central1-f --project <project-id>
Set PJRT
export PJRT_DEVICE=TPU
High-frequency commands. Save these into a file or define aliases for them.
Goal | Command |
---|---|
Stop | gcloud compute tpus tpu-vm stop your-tpu-name --zone=zone |
Start | gcloud compute tpus tpu-vm start your-tpu-name --zone=zone |
Delete | gcloud compute tpus tpu-vm delete your-tpu-name --zone=zone |
List | gcloud compute tpus tpu-vm list --zone=zone |
Search | gcloud compute tpus tpu-vm describe your-tpu-name --zone=zone |
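The table above can be wrapped in small shell functions so the TPU name and zone are typed only once. This is just a convenience sketch; TPU_NAME and TPU_ZONE are placeholders you should replace with your own values.

```shell
# Put this in ~/.bashrc (or a file you source). TPU_NAME and TPU_ZONE are
# placeholders for your own TPU name and zone.
TPU_NAME=your-tpu-name
TPU_ZONE=us-central1-f

tpu_stop()   { gcloud compute tpus tpu-vm stop     "$TPU_NAME" --zone="$TPU_ZONE"; }
tpu_start()  { gcloud compute tpus tpu-vm start    "$TPU_NAME" --zone="$TPU_ZONE"; }
tpu_delete() { gcloud compute tpus tpu-vm delete   "$TPU_NAME" --zone="$TPU_ZONE"; }
tpu_list()   { gcloud compute tpus tpu-vm list                 --zone="$TPU_ZONE"; }
tpu_show()   { gcloud compute tpus tpu-vm describe "$TPU_NAME" --zone="$TPU_ZONE"; }
```

After sourcing the file, `tpu_stop` and friends replace the full commands.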
Step 3 : Use XShell to connect
Go to the Google Cloud console, click Compute Engine in the side bar, then TPU, and copy the external IP. Next, go to the location where the SSH key pair generated by PuTTY is stored, namely google_compute_engine.ppk. Use PuTTYgen to export a private key. Finally, create a new session in XShell with Connection->Proxy and Connection->User Auth->Method->Public Key configured.
2 Training
I got my first free TPU, a v2-8. Although the docs recommend setting the PJRT env variable, PJRT failed for me while the old XRT works with torch-1.10. The required env variable is set below.
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
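To avoid re-exporting this on every SSH session, the line can be appended to ~/.bashrc once. A small idempotent sketch (the value is exactly the one above):

```shell
# Persist the XRT config so every new shell on the TPU VM picks it up;
# only append if the line is not already present.
XRT_LINE='export XRT_TPU_CONFIG="localservice;0;localhost:51011"'
grep -qxF "$XRT_LINE" ~/.bashrc 2>/dev/null || echo "$XRT_LINE" >> ~/.bashrc
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
```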
I prefer a virtual env to manage my projects, but there are no docs about Python virtual envs on the TPU VM. I found a tip in a GitHub issue: the TPU-specific packages live under /usr/share/tpu/
. However, those are TensorFlow packages only, so for PyTorch repos I create a venv that inherits the system packages.
virtualenv -p python3.8 vtorch --system-site-packages
source vtorch/bin/activate
However, the following errors occurred:
...
ImportError: cannot import name 'notf' from 'tensorboard.compat'
...
TypeError: Descriptors cannot not be created directly. If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0. If you cannot immediately regenerate your protos, some other possible workarounds are: 1. Downgrade the protobuf package to 3.20.x or lower. 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).
...
RuntimeError: tensorflow/compiler/xla/xla_client/xrt_local_service.cc:56 : Check failed: tensorflow::NewServer(server_def, &server_) == ::tensorflow::Status::OK() (UNKNOWN: Could not start gRPC server vs. OK)
A post suggested changing jax
and jaxlib
, which did not work. Following the hint in the error above, I downgraded protobuf
from 4.22.1 to 3.20, which fixed the problem.
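Concretely, the fix is a single pin to the 3.20.x line the error message asks for (3.20.3 here is one example release from that line):

```shell
# Downgrade protobuf from 4.x to the 3.20.x line suggested by the error above,
# then confirm the installed version.
pip install "protobuf==3.20.3"
pip show protobuf | grep ^Version
```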
I use pytorch-lightning for training, following the only tutorial I could find, and fixed the lightning-version problems and so on. Very sad.
In the end, the TPU is still about 30% slower than even an RTX 2070.