TPU v2 is about 30% slower than even an RTX 2070 with pytorch-lightning. You can try it if you work with TensorFlow; otherwise, forget the free TPU: it is much slower and has many bugs with poor documentation.
1 Preparation
Step 1 : Install Google Cloud CLI in PowerShell
(New-Object Net.WebClient).DownloadFile("https://dl.google.com/dl/cloudsdk/channels/rapid/GoogleCloudSDKInstaller.exe", "$env:Temp\GoogleCloudSDKInstaller.exe")
& $env:Temp\GoogleCloudSDKInstaller.exe
Optional : Connectivity is limited in China. On the campus network I got ERROR: gcloud crashed (SSLError): HTTPSConnectionPool(host='oauth2.googleapis.com', port=443): Max retries exceeded with url.
So I use clash-for-windows as a proxy for gcloud.
gcloud config set proxy/type http
gcloud config set proxy/address 127.0.0.1
gcloud config set proxy/port 7890
Then log in to gcloud and confirm the redirect in the browser. Login succeeded if the CLI outputs: You are now logged in as [MyAccount]. Your current project is [MyProject].
gcloud auth login
Set the Google Cloud project and account
gcloud config set project <your-project-id>
gcloud config set account <your-email-account>
Step 2 : Create TPU Virtual Machine
Check your dependencies: verify the versions of torch and the torch-related packages.
pip3 list | grep torch
E.g. I use pytorch-forecasting==0.10.2 and pytorch-lightning==1.7.2 in my project, so I looked up their torch requirements on www.pypi.org and chose the latest supported torch version. The free-trial TPU quota table is below; you can find it in the TRC response letter.
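Instead of browsing pypi.org, the recorded pins can also be queried locally. A minimal stdlib sketch (the package name in the example is from my project and may not be installed on your machine):

```python
from importlib import metadata

def torch_pins(pkg):
    """Return the torch-related requirement strings recorded for an installed package."""
    try:
        reqs = metadata.requires(pkg) or []
    except metadata.PackageNotFoundError:
        return []
    return [r for r in reqs if r.lower().startswith("torch")]

# e.g. torch_pins("pytorch-lightning") lists the torch pin your project implies
print(torch_pins("pytorch-lightning"))
```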
| TPU Type | Use Type | Num | Zone |
|---|---|---|---|
| v2-8 | on-demand | 5 | us-central1-f |
| v3-8 | on-demand | 5 | europe-west4-a |
| v2-8 | preemptible | 100 | us-central1-f |
Then fill in the following command from the table and create a TPU VM.
gcloud compute tpus tpu-vm create <your-tpu-name> --zone=us-central1-f --accelerator-type=v2-8 --version=tpu-vm-pt-1.10
Connect to the VM you created over SSH.
gcloud compute tpus tpu-vm ssh <your-tpu-name> --zone us-central1-f --project <project-id>
Set PJRT
export PJRT_DEVICE=TPU
High-frequency commands. Save them in a script file or as shell aliases.
| Goal | Command |
|---|---|
| Stop | gcloud compute tpus tpu-vm stop your-tpu-name --zone=zone |
| Start | gcloud compute tpus tpu-vm start your-tpu-name --zone=zone |
| Delete | gcloud compute tpus tpu-vm delete your-tpu-name --zone=zone |
| List | gcloud compute tpus tpu-vm list --zone=zone |
| Describe | gcloud compute tpus tpu-vm describe your-tpu-name --zone=zone |
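One way to "save them into a file" is a tiny helper that types the name and zone once. A sketch (the gcloud subcommands are exactly those in the table; the helper name and defaults are mine):

```python
import shlex
import subprocess

TPU_NAME = "my-tpu"      # hypothetical defaults; edit for your VM
ZONE = "us-central1-f"

def tpu_cmd(action, name=TPU_NAME, zone=ZONE):
    """Build the gcloud command list for stop/start/delete/describe/list."""
    cmd = ["gcloud", "compute", "tpus", "tpu-vm", action]
    if action != "list":
        cmd.append(name)      # `list` takes no TPU name
    cmd.append(f"--zone={zone}")
    return cmd

# Dry run: print the command instead of executing it.
print(shlex.join(tpu_cmd("stop")))
# To actually run it: subprocess.run(tpu_cmd("stop"), check=True)
```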
Step 3 : Use XShell to connect
Go to the Google Cloud console, click Compute Engine in the left sidebar, then TPUs, and copy the external IP. Then go to the location where the SSH key pairs generated by PuTTY are stored, namely google_compute_engine.ppk. Use PuTTYgen to export a private key. Create a new session in XShell with Connection->Proxy and Connection->User Auth->Method->Pub Key configured.
2 Use for training
I got my first free TPU with v2-8. Although they recommend setting the PJRT env variable, in my case PJRT failed while the old XRT works with torch-1.10. The env variable to set is below.
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
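The fallback can be sketched in the training script itself (the env variable names are the real torch_xla ones; the version check and selection logic are mine, and the version string is hard-coded here where you would read torch.__version__):

```python
import os

torch_version = "1.10"  # hypothetical; in practice read torch.__version__

if torch_version.startswith("1.10"):
    # Old XRT runtime: the local-service config from the export line above.
    os.environ["XRT_TPU_CONFIG"] = "localservice;0;localhost:51011"
    os.environ.pop("PJRT_DEVICE", None)  # keep PJRT from taking over
else:
    os.environ["PJRT_DEVICE"] = "TPU"

# Either variable must be set before torch_xla is first imported.
print(os.environ.get("XRT_TPU_CONFIG"))
```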
I prefer a virtual env to manage my projects, but there are no docs about Python virtual envs on the TPU VM. I found this tip in a GitHub issue: the TPU-specific packages live under /usr/share/tpu/. However, they are TensorFlow packages only, so for PyTorch repos I create the venv with the system packages copied in.
virtualenv -p python3.8 vtorch --system-site-packages
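As a sanity check that system packages are actually visible from such an env, a minimal sketch using the stdlib venv module instead of the virtualenv tool, with a throwaway directory:

```python
import tempfile
import venv
from pathlib import Path

# Recreate the idea of `virtualenv --system-site-packages` with stdlib venv.
env_dir = tempfile.mkdtemp()
venv.create(env_dir, system_site_packages=True, with_pip=False)

# The flag is recorded in pyvenv.cfg; system-wide torch stays importable from the env.
cfg = Path(env_dir, "pyvenv.cfg").read_text()
print("include-system-site-packages = true" in cfg.lower())
```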
However, the following errors occur:
...
ImportError: cannot import name 'notf' from 'tensorboard.compat'
...
TypeError: Descriptors cannot not be created directly. If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0. If you cannot immediately regenerate your protos, some other possible workarounds are: 1. Downgrade the protobuf package to 3.20.x or lower. 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).
...
RuntimeError: tensorflow/compiler/xla/xla_client/xrt_local_service.cc:56 : Check failed: tensorflow::NewServer(server_def, &server_) == ::tensorflow::Status::OK() (UNKNOWN: Could not start gRPC server vs. OK)
A post suggesting changing jax and jaxlib did not work for me. Following the hint above, downgrading protobuf from 4.22.1 to 3.20 fixed the problem.
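If you cannot downgrade, the error message's second workaround is an environment variable. As a sketch (the variable name comes straight from the error text; I have not needed this path myself since the downgrade worked), it must be set before protobuf is first imported:

```python
import os

# Must be set before google.protobuf is first imported in the process;
# it switches protobuf to the slower pure-Python parser.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
print(os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"])
```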
I use pytorch-lightning for training; I followed the only tutorial available and fixed problems with the lightning version and so on. Very sad.
Even compared with an RTX 2070, the TPU is about 30% slower.
