February 13, 09:02

PyTorch NLP best practices

Very simple ideas, actually.

(1) Multi GPU parallelization and FP16 training

Do not bother reinventing the wheel.

Just use NVIDIA's apex, DistributedDataParallel and DataParallel.

Best examples [here](github.com/huggingface/pytorch-pretrained-BERT).

(2) Put as much as possible INSIDE of the model

Implement as much of your logic as possible inside of nn.Module.

Why?

So that you can seamlessly use all the abstractions from (1).

Also models are more abstract and reusable in general.
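
A minimal sketch of what this can look like, not a definitive recipe. ClassifierWithLoss and all of its sizes are hypothetical names made up for illustration. The point is that the loss lives inside forward, so wrapping the model in DataParallel is the only change needed to go multi-GPU (the sketch simply stays on CPU when fewer than two GPUs are visible).

```python
import torch
import torch.nn as nn


class ClassifierWithLoss(nn.Module):
    """Hypothetical toy model: embedding, pooling, linear head and loss in one module."""

    def __init__(self, vocab_size=100, emb_dim=16, num_classes=3):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.head = nn.Linear(emb_dim, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, tokens, labels=None):
        pooled = self.emb(tokens).mean(dim=1)  # (batch, emb_dim)
        logits = self.head(pooled)             # (batch, num_classes)
        if labels is None:
            return logits
        # The loss is computed inside forward, so it runs on each replica.
        # unsqueeze(0) lets DataParallel gather per-replica losses into a vector.
        return logits, self.criterion(logits, labels).unsqueeze(0)


model = ClassifierWithLoss()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model).cuda()  # the wrapper is the only change

device = next(model.parameters()).device
tokens = torch.randint(0, 100, (8, 12), device=device)
labels = torch.randint(0, 3, (8,), device=device)

logits, loss = model(tokens, labels)
loss = loss.mean()  # average over replicas (a single value on one device)
loss.backward()
```

The training script itself never touches the loss function or the device layout; it just calls the model and averages what comes back.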

(3) Why have a separate train/val loop?

PyTorch 0.4 introduced context managers for gradient tracking (torch.no_grad() and torch.enable_grad()).

You can simplify your train / val / test loops, and merge them into one simple function.

    import torch
    from tqdm import tqdm

    # gradients are only needed for training
    context = torch.enable_grad() if loop_type == 'Train' else torch.no_grad()

    if loop_type == 'Train':
        model.train()
    else:
        model.eval()

    with context:
        # loader is the DataLoader for the current loop_type
        for i, some_tensor in enumerate(tqdm(loader)):
            # do your stuff here
            pass
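
For completeness, here is one way the merged loop could look as a full function. This is a sketch under assumptions: run_epoch, the toy linear model and the random data are hypothetical and only there so the snippet runs on its own.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, criterion, optimizer=None, loop_type="Train"):
    """Run one pass over loader; weights are updated only when loop_type == 'Train'."""
    is_train = loop_type == "Train"
    model.train(is_train)  # same as calling model.train() / model.eval()
    context = torch.enable_grad() if is_train else torch.no_grad()

    total_loss, n_samples = 0.0, 0
    with context:
        for inputs, targets in loader:
            loss = criterion(model(inputs), targets)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * inputs.size(0)
            n_samples += inputs.size(0)
    return total_loss / n_samples


# Toy data and model, just to exercise the function
torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
loader = DataLoader(data, batch_size=16)
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_loss = run_epoch(model, loader, criterion, optimizer, loop_type="Train")
val_loss = run_epoch(model, loader, criterion, loop_type="Val")
```

The same function serves validation and test: anything other than "Train" runs in eval mode without gradients.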

(4) EmbeddingBag

Use EmbeddingBag layer for morphologically rich languages. Seriously!
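
A rough sketch of why this helps, assuming a fastText-style setup where each word is represented as a bag of hashed character n-grams. That setup is my assumption, not something the layer requires; NUM_BUCKETS, EMB_DIM and the helper functions are hypothetical. EmbeddingBag takes a flat tensor of ids plus offsets marking where each bag starts, and returns one pooled vector per bag.

```python
import zlib

import torch
import torch.nn as nn

NUM_BUCKETS = 10_000  # hypothetical hash-table size
EMB_DIM = 32          # hypothetical embedding size


def char_ngrams(word, n_min=3, n_max=5):
    """Character n-grams of a word padded with boundary markers."""
    padded = f"<{word}>"
    return [padded[i:i + n]
            for n in range(n_min, n_max + 1)
            for i in range(len(padded) - n + 1)]


def ngram_ids(word):
    """Map each n-gram to a bucket with a deterministic hash (crc32)."""
    return [zlib.crc32(g.encode("utf-8")) % NUM_BUCKETS for g in char_ngrams(word)]


def words_to_bag_input(words):
    """Flatten per-word id lists into the (input, offsets) pair EmbeddingBag expects."""
    flat, offsets = [], []
    for word in words:
        offsets.append(len(flat))  # index where this word's bag starts
        flat.extend(ngram_ids(word))
    return torch.tensor(flat), torch.tensor(offsets)


bag = nn.EmbeddingBag(NUM_BUCKETS, EMB_DIM, mode="mean")

# Inflected forms of the same lemma share many of their n-grams
words = ["книга", "книги", "книгой"]
flat_ids, offsets = words_to_bag_input(words)
word_vectors = bag(flat_ids, offsets)  # one vector per word: (3, EMB_DIM)
```

Because inflected forms share n-grams, they also share embedding rows, so rare or unseen word forms still get a sensible vector.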

(5) Writing trainers / training abstractions

This is a waste of time, imho, if you follow (1), (2) and (3).

(6) Nice bonus

If you follow most of these, you can train on as many GPUs and machines as you want, for any language.

(7) Using tensorboard for logging

This goes without saying.

#nlp

#deep_learning



PyTorch DataLoader, GIL thrashing and CNNs

Well, all of this seems a bit like magic to me, but hear me out.

I abused my GPU box for weeks running CNNs on 2-4 GPUs.

Nothing broke.

And then my GPU box started shutting down for no apparent reason.

No, this was not:

- CPU overheating (I have a massive cooler, I checked - it works);

- PSU;

- Overclocking;

It also adds to the confusion that AMD has weird temperature readings.

To cut a long story short: if you have a very fast Dataset class and you use PyTorch's DataLoader with num_workers > 0, it can lead to system instability instead of a speed-up.

It is obvious in retrospect, but it is not when you face this issue.
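
One possible guard, as a sketch only: time a few __getitem__ calls in the main process and fall back to num_workers=0 when the dataset is already fast. pick_num_workers and its fast_threshold_s cut-off are hypothetical and untuned, so measure on your own box before trusting any number here.

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def pick_num_workers(dataset, max_workers=4, n_probe=256, fast_threshold_s=1e-4):
    """Hypothetical heuristic: skip worker processes if __getitem__ is already fast.

    fast_threshold_s is an arbitrary cut-off (seconds per sample), not a tuned value.
    """
    n = min(n_probe, len(dataset))
    start = time.perf_counter()
    for i in range(n):
        dataset[i]  # fetch a sample in the main process, discard the result
    per_sample = (time.perf_counter() - start) / n
    return 0 if per_sample < fast_threshold_s else max_workers


# Toy in-memory dataset: indexing it is about as cheap as a Dataset gets
dataset = TensorDataset(torch.randn(1024, 8), torch.randint(0, 2, (1024,)))
num_workers = pick_num_workers(dataset)
loader = DataLoader(dataset, batch_size=64, num_workers=num_workers)
```

With num_workers=0 the data is loaded in the main process, which removes the extra worker processes from the picture entirely.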

#deep_learning

#pytorch