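【GAN】CycleGAN Study: Training Process Walkthrough
Code explained in this post:
Official source code:
The official source code follows the same approach as the code explained here. This post focuses on walking through the overall training process; if you want to study CycleGAN in depth, I recommend reading the official source code, which is also fairly simple.
Training Process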
1. Train Generators
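Loss function: the generator loss is the adversarial (GAN) loss plus the cycle-consistency loss weighted by lambda_cyc and the identity loss weighted by lambda_id. Each term is the average of its A and B directions: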
```python
loss_identity = (loss_id_A + loss_id_B) / 2
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
# Total loss
loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity
```
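As a worked example of how the three terms are combined, here is a small pure-Python helper. The weights lambda_cyc = 10.0 and lambda_id = 5.0 are assumptions (defaults seen in common PyTorch CycleGAN implementations), not values stated in this post, and the component loss values are made up so the arithmetic is easy to follow.

```python
def generator_loss(loss_gan_ab, loss_gan_ba,
                   loss_cycle_a, loss_cycle_b,
                   loss_id_a, loss_id_b,
                   lambda_cyc=10.0, lambda_id=5.0):
    """Combine CycleGAN generator loss terms like the formula above.

    The lambda_cyc / lambda_id defaults are assumptions for illustration.
    """
    # Average each pair of directional losses (A->B and B->A)
    loss_gan = (loss_gan_ab + loss_gan_ba) / 2
    loss_cycle = (loss_cycle_a + loss_cycle_b) / 2
    loss_identity = (loss_id_a + loss_id_b) / 2
    # Weighted sum: adversarial + cycle-consistency + identity
    return loss_gan + lambda_cyc * loss_cycle + lambda_id * loss_identity


# Hypothetical component losses
total = generator_loss(0.5, 0.25, 0.25, 0.125, 0.0625, 0.1875)
print(total)  # → 2.875
```

With these example numbers the averaged terms are 0.375 (GAN), 0.1875 (cycle) and 0.125 (identity), so the weighted cycle term contributes 1.875 of the 2.875 total, which shows how strongly lambda_cyc scales its influence.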
Code:
```python
# ------------------
#  Train Generators
# ------------------
optimizer_G.zero_grad()

# Identity loss: G_BA should leave a real domain-A image unchanged (and G_AB a domain-B image)
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)
loss_identity = (loss_id_A + loss_id_B) / 2

# GAN loss: the generators try to make the discriminators label their outputs as valid
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

# Cycle loss: A -> B -> A (and B -> A -> B) should reconstruct the original image
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)
recov_B = G_AB(fake_A)
loss_cycle_B = criterion_cycle(recov_B, real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

# Total loss
loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity

loss_G.backward()
optimizer_G.step()
```
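The code above does not show how criterion_GAN, criterion_cycle and criterion_identity are defined, nor the target tensors valid and fake. A common setup, assumed here, uses a least-squares (MSE) adversarial loss with valid filled with ones, and L1 loss for the cycle and identity terms. The pure-Python functions below mimic the mean-reduced behaviour of torch.nn.MSELoss and torch.nn.L1Loss on flat lists; all scores and pixel values are hypothetical.

```python
def mse_loss(pred, target):
    """Mean squared error, like torch.nn.MSELoss() with reduction='mean'."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1_loss(pred, target):
    """Mean absolute error, like torch.nn.L1Loss() with reduction='mean'."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


# Hypothetical D_B outputs on fake_B = G_AB(real_A)
d_b_scores_on_fake = [0.5, 0.25, 0.75, 0.5]
valid = [1.0] * len(d_b_scores_on_fake)  # the generator wants D_B to say "real"
loss_gan_ab = mse_loss(d_b_scores_on_fake, valid)

# Hypothetical flattened pixels: real_A and its reconstruction G_BA(G_AB(real_A))
real_a = [0.0, 0.5, 1.0, -0.5]
recov_a = [0.25, 0.5, 0.75, -0.5]
loss_cycle_a = l1_loss(recov_a, real_a)

print(loss_gan_ab, loss_cycle_a)  # → 0.28125 0.125
```

The generator's adversarial loss only reaches zero when the discriminator outputs 1 everywhere, while the cycle loss reaches zero only when the reconstruction matches the input pixel for pixel.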
2. Train Discriminator A
Loss function (the average of the loss on real A images and on generated A images):
loss_D_A = (loss_real + loss_fake) / 2
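Assuming criterion_GAN is a least-squares (MSE) loss, loss_real pushes D_A's scores on real images toward 1 (valid) and loss_fake pushes its scores on generated images toward 0 (fake). A minimal pure-Python sketch with hypothetical discriminator outputs:

```python
def mse_loss(pred, target):
    """Mean squared error, like torch.nn.MSELoss() with reduction='mean'."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def discriminator_loss(scores_real, scores_fake):
    """LSGAN-style discriminator loss: real -> 1, fake -> 0, then average."""
    loss_real = mse_loss(scores_real, [1.0] * len(scores_real))  # target: valid
    loss_fake = mse_loss(scores_fake, [0.0] * len(scores_fake))  # target: fake
    return (loss_real + loss_fake) / 2


# Hypothetical D_A outputs on real_A and on (buffered) fake_A
loss_d_a = discriminator_loss([0.75, 1.0, 0.5, 1.0], [0.5, 0.0, 0.5, 0.0])
print(loss_d_a)  # → 0.1015625
```

Halving the sum means the discriminator is updated with half the weight of the real and fake terms combined, which slows it down relative to the generators.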
Code:
```python
# -----------------------
#  Train Discriminator A
# -----------------------
optimizer_D_A.zero_grad()

# Real loss: D_A should label real domain-A images as valid
loss_real = criterion_GAN(D_A(real_A), valid)
# Fake loss (on batch of previously generated samples)
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
# .detach() keeps this update from backpropagating into G_BA
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
# Total loss
loss_D_A = (loss_real + loss_fake) / 2

loss_D_A.backward()
optimizer_D_A.step()
```
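fake_A_buffer.push_and_pop(fake_A) feeds the discriminator a mix of newly generated images and older ones from a history buffer, so D_A does not only see the generator's latest outputs. The buffer class is not shown in this post; the sketch below is a pure-Python re-creation of the usual scheme (store up to max_size samples; once full, with probability 0.5 return a stored sample and replace it with the new one, otherwise return the new sample). The class name, max_size=50 and the 0.5 probability are assumptions based on common CycleGAN implementations, and the items are plain Python objects rather than tensors.

```python
import random


class ReplayBuffer:
    """History of generated samples (hypothetical pure-Python sketch)."""

    def __init__(self, max_size=50, rng=None):
        assert max_size > 0, "buffer size must be positive"
        self.max_size = max_size
        self.data = []
        self.rng = rng if rng is not None else random.Random()

    def push_and_pop(self, batch):
        """Return a batch of the same length, mixing new and stored samples."""
        out = []
        for element in batch:
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the sample and return it unchanged
                self.data.append(element)
                out.append(element)
            elif self.rng.random() > 0.5:
                # Return an older sample and put the new one in its slot
                i = self.rng.randint(0, self.max_size - 1)
                out.append(self.data[i])
                self.data[i] = element
            else:
                # Return the new sample without storing it
                out.append(element)
        return out


buf = ReplayBuffer(max_size=3, rng=random.Random(0))
print(buf.push_and_pop(["f0", "f1"]))        # → ['f0', 'f1'] (buffer still filling)
print(buf.push_and_pop(["f2", "f3", "f4"]))  # mix of new and older samples
```

Until the buffer fills, the discriminator sees exactly what the generator just produced; afterwards roughly half of each batch comes from earlier generator outputs.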
3. Train Discriminator B
Loss function (the average of the loss on real B images and on generated B images):
loss_D_B = (loss_real + loss_fake) / 2
Code:
```python
# -----------------------
#  Train Discriminator B
# -----------------------
optimizer_D_B.zero_grad()

# Real loss: D_B should label real domain-B images as valid
loss_real = criterion_GAN(D_B(real_B), valid)
# Fake loss (on batch of previously generated samples)
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
# .detach() keeps this update from backpropagating into G_AB
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
# Total loss
loss_D_B = (loss_real + loss_fake) / 2

loss_D_B.backward()
optimizer_D_B.step()

# Combined discriminator loss; not backpropagated here (e.g. useful for logging)
loss_D = (loss_D_A + loss_D_B) / 2
```