Addestramento di Resnet50 su Cloud TPU con PyTorch

Questo tutorial mostra come addestrare il modello ResNet-50 su un dispositivo Cloud TPU con PyTorch. Puoi applicare lo stesso pattern ad altri modelli di classificazione delle immagini ottimizzati per TPU che utilizzano PyTorch e il set di dati ImageNet.

Il modello in questo tutorial si basa su Deep Residual Learning for Image Recognition, che introduce per la prima volta l'architettura della rete residuale (ResNet). Il tutorial utilizza la variante a 50 livelli, ResNet-50, e mostra l'addestramento del modello utilizzando PyTorch/XLA.

Crea una VM TPU

  1. Apri una finestra di Cloud Shell.

    Apri Cloud Shell

  2. Crea una VM TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. Connettiti alla VM TPU utilizzando SSH:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. Installa PyTorch/XLA sulla VM TPU:

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. Clona il repository GitHub PyTorch/XLA

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. Esegui lo script di addestramento con dati falsi

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1