Batch size for real vs fake images

I was trying to add a WGAN-GP loss function to my anime-faces GAN because the model was falling apart (the real score was going to 1).
I found this GitHub repository that implements it in PyTorch:

It has a step that mixes each real image with a generated image to produce an interpolated image.
However, I noticed that on the last batch, real_images (from the DataLoader) holds only 77 images, while the fake batch still holds 128 (the generated images come from random noise, so their count is set by the batch-size hyperparameter). This results in the following error:

in gradient_penalty(discriminator, real_data, generated_data, gp_weight)
9 alpha = alpha.expand_as(real_data)
10 alpha = to_device(alpha, device)
---> 11 interpolated = alpha * + (1 - alpha) *
12 interpolated = Variable(interpolated, requires_grad=True)
13 interpolated = to_device(interpolated, device)

RuntimeError: The size of tensor a (77) must match the size of tensor b (128) at non-singleton dimension 0
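The mismatch can be reproduced in isolation. This is a minimal sketch, assuming 3×8×8 images (the shapes are made up for illustration): as in the traceback, alpha is expanded to the shape of the 77-image real batch, while the generated batch still has 128 images, so the product with generated_data cannot broadcast along dimension 0.

```python
import torch

real_data = torch.randn(77, 3, 8, 8)        # short last batch from the DataLoader
generated_data = torch.randn(128, 3, 8, 8)  # fakes generated with the fixed batch size

# One coefficient per real sample, expanded to real_data's shape as in the traceback
alpha = torch.rand(real_data.size(0), 1, 1, 1).expand_as(real_data)

message = ""
try:
    interpolated = alpha * real_data + (1 - alpha) * generated_data
except RuntimeError as err:
    # (1 - alpha) has 77 rows, generated_data has 128: broadcasting fails on dim 0
    message = str(err)
print(message)
```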

For this scenario, should I take the batch size from the size of real_images coming from the data loader, instead of using the batch-size hyperparameter?
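Taking the batch size from the real batch could look like the following minimal sketch of the standard WGAN-GP gradient penalty. It is not the code from the linked repository: the function body, the default gp_weight=10 (the value used in the WGAN-GP paper), the shape check, and the toy linear critic in the usage lines are assumptions for illustration.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real_data, generated_data, gp_weight=10.0):
    """WGAN-GP penalty: gp_weight * mean((||grad critic(x_hat)||_2 - 1)^2)."""
    # Batch size comes from the real batch, so a short last batch (e.g. 77) works
    batch_size = real_data.size(0)
    if generated_data.size(0) != batch_size:
        raise ValueError(
            f"real batch has {batch_size} images but generated batch has "
            f"{generated_data.size(0)}; generate one fake per real image"
        )

    # One interpolation coefficient per sample, broadcast over C, H, W
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
    interpolated = alpha * real_data.detach() + (1 - alpha) * generated_data.detach()
    interpolated.requires_grad_(True)

    scores = critic(interpolated)
    # Gradient of the critic scores with respect to the interpolated images
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )[0]
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return gp_weight * ((grad_norm - 1) ** 2).mean()


# Usage with a hypothetical toy critic on 3x8x8 images
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real_batch = torch.randn(77, 3, 8, 8)  # short last batch
fake_batch = torch.randn(77, 3, 8, 8)  # fakes sized to match
penalty = gradient_penalty(critic, real_batch, fake_batch)
penalty.backward()
```

Raising early on a size mismatch gives a clearer message than the broadcasting error, but the real fix is still to generate as many fakes as there are real images.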

Also, will having different batch sizes for real and generated images cause problems for the model?


Quick fix

PyTorch’s DataLoader accepts an additional argument, drop_last. It drops the last batch if its size would be smaller than the requested batch size.
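A minimal sketch of the difference, assuming a dummy dataset of 205 tensors (one full batch of 128 plus a leftover of 77, matching the numbers in the question):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 205 dummy "images": one full batch of 128 plus a leftover of 77
dataset = TensorDataset(torch.randn(205, 3, 8, 8))

loader_keep = DataLoader(dataset, batch_size=128)                  # default drop_last=False
loader_drop = DataLoader(dataset, batch_size=128, drop_last=True)  # discard the short batch

sizes_keep = [batch.size(0) for (batch,) in loader_keep]
sizes_drop = [batch.size(0) for (batch,) in loader_drop]
print(sizes_keep)  # [128, 77]
print(sizes_drop)  # [128]
```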

If you also shuffle the training dataset, you don’t need to worry about examples that are never seen: with shuffling, a different set of examples ends up in the dropped batch each epoch.
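A small sketch of that effect, assuming a toy dataset of 10 indices with batch_size=4 (two full batches, and a leftover of 2 that drop_last discards). The seeded generator is only there to make the run repeatable:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
g = torch.Generator().manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True, generator=g)

seen = set()
dropped_per_epoch = []
for epoch in range(30):
    # Indices that made it into a full batch this epoch
    epoch_seen = {int(i) for (batch,) in loader for i in batch}
    dropped_per_epoch.append(set(range(10)) - epoch_seen)
    seen |= epoch_seen

# Each epoch drops 2 indices, but which ones changes from epoch to epoch
print(dropped_per_epoch[:3])
print(sorted(seen))
```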

No leftovers

Alternatively, you can generate the fake batch with a matching number of images. Since you probably generate fake_data close to where you load real_data, you can make it contain 77 images as well.
Hint: torch tensors have a size() method. It accepts an argument specifying the index of the dimension.
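A quick illustration of the hint, assuming a batch of 77 images at 3×64×64 (the image size is made up):

```python
import torch

real_images = torch.zeros(77, 3, 64, 64)

print(real_images.size())    # torch.Size([77, 3, 64, 64])
print(real_images.size(0))   # 77: size of dimension 0, the batch dimension
print(real_images.shape[0])  # 77: same value through the .shape attribute
```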


Thanks Sebastian.
I took the “No leftovers” approach, set the batch size from the real batch when generating the latent tensor, and it worked:

# Match the number of fakes to the real batch, including the short last one
actual_batch_size = real_images.size(0)
latent = torch.randn(actual_batch_size, latent_size, 1, 1, device=device)
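To check that sizing the latent tensor from the real batch keeps real and fake batches aligned across a whole epoch, here is a sketch with a hypothetical toy generator. latent_size=16, the 8×8 image size, and the 205-image dummy dataset are assumptions for illustration, not the poster’s model:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cpu")
latent_size = 16  # hypothetical

# Hypothetical toy generator: (N, latent_size, 1, 1) noise -> (N, 3, 8, 8) images
generator = nn.Sequential(nn.ConvTranspose2d(latent_size, 3, kernel_size=8), nn.Tanh()).to(device)

# 205 dummy real images -> batches of 128 and 77
loader = DataLoader(TensorDataset(torch.randn(205, 3, 8, 8)), batch_size=128, shuffle=True)

batch_shapes = []
for (real_images,) in loader:
    real_images = real_images.to(device)
    actual_batch_size = real_images.size(0)
    latent = torch.randn(actual_batch_size, latent_size, 1, 1, device=device)
    with torch.no_grad():
        fake_images = generator(latent)
    # Equal shapes mean alpha * real + (1 - alpha) * fake can broadcast
    batch_shapes.append((tuple(real_images.shape), tuple(fake_images.shape)))

print(batch_shapes)
```

Unlike drop_last, this keeps every real image in every epoch, at the cost of one smaller batch.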