Генеративно-состязательная сеть, которую вы построите, создаёт персонажей из манги и аниме. Рисуйте вайфу в своё удовольствие!
Давно хотели создать своих Аску, Код 002 или Канеки Кена? У вас появилась отличная возможность это сделать :)
Что такое генеративно-состязательная сеть?
Лучший вывод, который может генерировать нейронная сеть, похож на человеческий. Образно генеративно-состязательная сеть (GAN) может даже обмануть человека, заставив его думать, что вывод сделан им самим.
В генеративно-состязательных сетях две сети соревнуются друг с другом, что приводит к взаимным импровизациям. Генератор обманывает дискриминатор, создавая ложные входы и выдавая их за реальные. Дискриминатор сообщает, является ли ввод реальным или ложным.
В обучении GAN есть три главных шага:
-
- Используйте генератор для создания ложных входов из случайного шума.
- Обучите дискриминатор на ложных и реальных входах (одновременно с объединением или поочерёдно, что предпочтительнее).
- Обучите всю модель: дискриминатор + генератор.
Помните, что весовые коэффициенты дискриминатора «заморожены» во время последнего шага.
Причина сочетания обоих сетей состоит в отсутствии обратной связи на выходах генератора. Единственный ориентир – если дискриминатор принимает выходы генератора.
Можно сказать, что они соперничают друг с другом. Генератор обучается во время схватки с «соперником», чтобы реализовать цель.
Наша сеть
Для задачи мы используем глубокую сверточную генеративно-состязательную сеть.
Вот несколько рекомендаций для таких сетей:
- Замените максимальные подвыборки шагами свёртки.
- Используйте перемещённую свёртку для повышения частоты дискретизации.
- Устраните полностью соединённые слои.
- Используйте пакетную нормализацию, кроме выходного слоя генератора и входного слоя дискриминатора.
- Используйте ReLU в генераторе, кроме выхода, который использует tanh.
- Leaky ReLU в дискриминаторе.
Детали сетапа
- версия Keras==2.2.4
- TensorFlow==1.8.0
- Jupyter Notebook
- Matplotlib и другие библиотеки типа NumPy, Pandas
- Python==3.5.7
Набор данных
Набор данных для лиц из аниме можно собрать на тематических сайтах, скачивая картинки и вырезая лица. Код Python, который демонстрирует это.
Также доступны обработанные и обрезанные лица.
Генератор
Он состоит из сверточных слоёв, пакетной нормализации и функции активации Leaky ReLU для повышения частоты дискретизации. Мы используем параметр шагов в сверточном слое, чтобы избежать нестабильной обучаемости GAN. Функция не будет равно нулю, если x < 0, вместо этого Leaky ReLU имеет небольшое отрицательное отклонение (0.01 и так далее).
Код
def get_gen_normal(noise_shape):
kernel_init = 'glorot_uniform'
gen_input = Input(shape = noise_shape)
generator = Conv2DTranspose(filters = 512, kernel_size = (4,4), strides = (1,1), padding = "valid", data_format = "channels_last", kernel_initializer = kernel_init)(gen_input)
generator = BatchNormalization(momentum = 0.5)(generator)
generator = LeakyReLU(0.2)(generator)
generator = Conv2DTranspose(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
generator = BatchNormalization(momentum = 0.5)(generator)
generator = LeakyReLU(0.2)(generator)
generator = Conv2DTranspose(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
generator = BatchNormalization(momentum = 0.5)(generator)
generator = LeakyReLU(0.2)(generator)
generator = Conv2DTranspose(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
generator = BatchNormalization(momentum = 0.5)(generator)
generator = LeakyReLU(0.2)(generator)
generator = Conv2D(filters = 64, kernel_size = (3,3), strides = (1,1), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
generator = BatchNormalization(momentum = 0.5)(generator)
generator = LeakyReLU(0.2)(generator)
generator = Conv2DTranspose(filters = 3, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
generator = Activation('tanh')(generator)
gen_opt = Adam(lr=0.00015, beta_1=0.5)
generator_model = Model(input = gen_input, output = generator)
generator_model.compile(loss='binary_crossentropy', optimizer=gen_opt, metrics=['accuracy'])
generator_model.summary()
return generator_model
Дискриминатор
Он также состоит из сверточных слоёв, где мы используем шаги для понижения частоты дискретизации и пакетной нормализации для стабильности.
Код
def get_disc_normal(image_shape=(64,64,3)):
dropout_prob = 0.4
kernel_init = 'glorot_uniform'
dis_input = Input(shape = image_shape)
discriminator = Conv2D(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(dis_input)
discriminator = LeakyReLU(0.2)(discriminator)
discriminator = Conv2D(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)
discriminator = BatchNormalization(momentum = 0.5)(discriminator)
discriminator = LeakyReLU(0.2)(discriminator)
discriminator = Conv2D(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)
discriminator = BatchNormalization(momentum = 0.5)(discriminator)
discriminator = LeakyReLU(0.2)(discriminator)
discriminator = Conv2D(filters = 512, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)
discriminator = BatchNormalization(momentum = 0.5)(discriminator)
discriminator = LeakyReLU(0.2)(discriminator)
discriminator = Flatten()(discriminator)
discriminator = Dense(1)(discriminator)
discriminator = Activation('sigmoid')(discriminator)
dis_opt = Adam(lr=0.0002, beta_1=0.5)
discriminator_model = Model(input = dis_input, output = discriminator)
discriminator_model.compile(loss='binary_crossentropy', optimizer=dis_opt, metrics=['accuracy'])
discriminator_model.summary()
return discriminator_model
Полная GAN
Чтобы генератор проверял выход, скомпилируем сеть в Keras. В этой сети входом будет случайный шум для генератора, а выходом – выход генератора «скормленный» дискриминатору. Во избежание «коллапса состязания» весовые коэффициенты дискриминатора заморожены.
Код
discriminator.trainable = False
opt = Adam(lr=0.00015, beta_1=0.5) #same as generator
gen_inp = Input(shape=noise_shape)
GAN_inp = generator(gen_inp)
GAN_opt = discriminator(GAN_inp)
gan = Model(input = gen_inp, output = GAN_opt)
gan.compile(loss = 'binary_crossentropy', optimizer = opt, metrics=['accuracy'])
gan.summary()
Тренировка модели
Базовая конфигурация модели
1. Сгенерируйте случайный нормальный шум для входа:
def gen_noise(batch_size, noise_shape):
return np.random.normal(0, 1, size=(batch_size,)+noise_shape)
2. Объедините реальные данные из набора с шумом:
data_X = np.concatenate([real_data_X, fake_data_X])
real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
fake_data_Y = np.random.random_sample(batch_size)*0.2
data_Y = np.concatenate([real_data_Y, fake_data_Y])
3. Подайте шум на вход:
real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
fake_data_Y = np.random.random_sample(batch_size)*0.2
4. Тренируем только генератор
print("Begin step: ", tot_step)
step_begin_time = time.time()
real_data_X = sample_from_dataset(batch_size, image_shape, data_dir=data_dir)
noise = gen_noise(batch_size,noise_shape)
fake_data_X = generator.predict(noise)
if (tot_step % 100) == 0:
step_num = str(tot_step).zfill(4)
save_img_batch(fake_data_X,img_save_dir+step_num+"_image.png")
5. Тренируем только дискриминатор:
discriminator.trainable = True
generator.trainable = False
dis_metrics_real = discriminator.train_on_batch(real_data_X,real_data_Y)
dis_metrics_fake = discriminator.train_on_batch(fake_data_X,fake_data_Y)
print("Disc: real loss: %f fake loss: %f" % (dis_metrics_real[0], dis_metrics_fake[0]))
6. Тренируем совмещённую GAN:
generator.trainable = True
discriminator.trainable = False
GAN_X = gen_noise(batch_size,noise_shape)
GAN_Y = real_data_Y
gan_metrics = gan.train_on_batch(GAN_X,GAN_Y)
print("GAN loss: %f" % (gan_metrics[0]))
text_file = open(log_dir+"\\training_log.txt", "a")
text_file.write("Step: %d Disc: real loss: %f fake loss: %f GAN loss: %f\n" % (tot_step, dis_metrics_real[0], dis_metrics_fake[0],gan_metrics[0]))
text_file.close()
avg_GAN_loss.append(gan_metrics[0])
end_time = time.time()
diff_time = int(end_time - step_begin_time)
print("Step %d completed. Time took: %s secs." % (tot_step, diff_time))
7. Сохраните экземпляры дискриминатора и генератора:
if ((tot_step+1) % 500) == 0:
print("-----------------------------------------------------------------")
print("Average Disc_fake loss: %f" % (np.mean(avg_disc_fake_loss)))
print("Average Disc_real loss: %f" % (np.mean(avg_disc_real_loss)))
print("Average GAN loss: %f" % (np.mean(avg_GAN_loss)))
print("-----------------------------------------------------------------")
discriminator.trainable = False
generator.trainable = False
# predict on fixed_noise
fixed_noise_generate = generator.predict(noise)
step_num = str(tot_step).zfill(4)
save_img_batch(fixed_noise_generate,img_save_dir+step_num+"fixed_image.png")
generator.save(save_model_dir+str(tot_step)+"_GENERATOR_weights_and_arch.hdf5")
discriminator.save(save_model_dir+str(tot_step)+"_DISCRIMINATOR_weights_and_arch.hdf5")
Результаты Манга-генератора
После 10000 шагов обучения результат выглядит круто! Смотрите сами.
Более длительная тренировка с большим набором данных приведёт к лучшим результатам. (Некоторые лица получились страшными, это правда :D)
Заключение
Наверняка задача генерации лиц в стиле аниме интересна.
Но в ней ещё есть место для улучшений: лучшее обучение, модели и набор данных. Наша модель вряд ли заставит человека задуматься, реальны ли сгенерированные модели. Тем не менее, она отлично справляется с поставленной задачей. Вы можете улучшить сеть, используя персонажей во весь рост.
Исходный код проекта доступен на GitHub.
Комментарии